From c5ffc7374566e28d8532856a97c6debf80f10175 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 26 Jun 2024 17:25:10 -0700 Subject: [PATCH 01/41] add genai.phi and tests --- Microsoft.ML.sln | 24 +- .../Microsoft.ML.GenAI.Core.csproj | 7 +- .../Module/GenAILinear.cs | 48 ++ .../Module/NewGELUActivation.cs | 24 + .../Pipeline/CasualLMModelOutput.cs | 4 +- .../Pipeline/CausalLMPipeline.cs | 2 +- .../Microsoft.ML.GenAI.Phi.csproj | 22 + .../Module/Phi2Attention.cs | 155 ++++++ .../Module/Phi2DecoderLayer.cs | 62 +++ src/Microsoft.ML.GenAI.Phi/Module/Phi2MLP.cs | 33 ++ .../Module/Phi2Model.cs | 154 ++++++ .../Module/Phi2RotaryEmbedding.cs | 45 ++ .../Module/Phi3Attention.cs | 192 ++++++++ .../Module/Phi3DecoderLayer.cs | 125 +++++ src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs | 46 ++ .../Module/Phi3Model.cs | 134 ++++++ .../Module/Phi3RMSNorm.cs | 55 +++ .../Module/Phi3RotaryEmbedding.cs | 77 +++ .../Module/Phi3SuScaledRotaryEmbedding.cs | 75 +++ src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs | 101 ++++ .../Phi2/Phi2Extension.cs | 15 + .../Phi2/Phi2ForCasualLM.cs | 64 +++ src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs | 113 +++++ .../Phi3/Phi3ForCasualLM.cs | 65 +++ .../Phi3/Phi3Tokenzier.cs | 222 +++++++++ src/Microsoft.ML.GenAI.Phi/Utils.cs | 159 ++++++ src/Microsoft.ML.Tokenizers/Tokenizer.cs | 18 + ...2Test.LoadSafeTensorShapeTest.approved.txt | 453 ++++++++++++++++++ .../Phi2Test.TokenizerTest.approved.txt | 3 + ...Tests.Phi3Medium128KShapeTest.approved.txt | 243 ++++++++++ ...i3Tests.Phi3Medium4KShapeTest.approved.txt | 243 ++++++++++ ...sts.Phi3Mini128KLayerSizeTest.approved.txt | 34 ++ ...i3Tests.Phi3Mini128KShapeTest.approved.txt | 195 ++++++++ ...Phi3Tests.Phi3Mini4KShapeTest.approved.txt | 195 ++++++++ .../Phi3Tests.TokenizerTest.approved.txt | 5 + .../Microsoft.ML.GenAI.Phi.Tests.csproj | 33 ++ test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs | 93 ++++ .../Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs | 127 +++++ 38 files changed, 3660 insertions(+), 5 deletions(-) create mode 100644 src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs create mode 100644 src/Microsoft.ML.GenAI.Core/Module/NewGELUActivation.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj create mode 100644 src/Microsoft.ML.GenAI.Phi/Module/Phi2Attention.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Module/Phi2DecoderLayer.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Module/Phi2MLP.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Module/Phi2Model.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Module/Phi2RotaryEmbedding.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Module/Phi3Attention.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Module/Phi3DecoderLayer.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Module/Phi3Model.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Module/Phi3RMSNorm.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Module/Phi3RotaryEmbedding.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Module/Phi3SuScaledRotaryEmbedding.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Phi2/Phi2Extension.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Phi2/Phi2ForCasualLM.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Utils.cs create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi2Test.LoadSafeTensorShapeTest.approved.txt create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi2Test.TokenizerTest.approved.txt create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Medium128KShapeTest.approved.txt create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Medium4KShapeTest.approved.txt create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini128KLayerSizeTest.approved.txt create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini128KShapeTest.approved.txt create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KShapeTest.approved.txt create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.TokenizerTest.approved.txt create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 1fa8823763..824f88dd5f 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -176,7 +176,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TorchSharp.Tes EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TensorFlow.Tests", "test\Microsoft.ML.TensorFlow.Tests\Microsoft.ML.TensorFlow.Tests.csproj", "{763FF013-8309-4680-A769-B54E7BB99612}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.Core", "src\Microsoft.ML.GenAI.Core\Microsoft.ML.GenAI.Core.csproj", "{DB2CA055-8ABD-4E3E-8089-5B64C3415E85}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.Core", "src\Microsoft.ML.GenAI.Core\Microsoft.ML.GenAI.Core.csproj", "{DB2CA055-8ABD-4E3E-8089-5B64C3415E85}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.Phi", "src\Microsoft.ML.GenAI.Phi\Microsoft.ML.GenAI.Phi.csproj", "{694BF884-B2E4-4E1C-9342-0564BAAC4575}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.Phi.Tests", "test\Microsoft.ML.GenAI.Phi.Tests\Microsoft.ML.GenAI.Phi.Tests.csproj", "{867FFC34-DFA7-400F-B9BB-85158326CE08}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -838,6 +842,22 @@ Global {DB2CA055-8ABD-4E3E-8089-5B64C3415E85}.Release|Any CPU.Build.0 = Release|Any CPU {DB2CA055-8ABD-4E3E-8089-5B64C3415E85}.Release|x64.ActiveCfg = Release|Any CPU {DB2CA055-8ABD-4E3E-8089-5B64C3415E85}.Release|x64.Build.0 = Release|Any CPU + {694BF884-B2E4-4E1C-9342-0564BAAC4575}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {694BF884-B2E4-4E1C-9342-0564BAAC4575}.Debug|Any CPU.Build.0 = Debug|Any CPU + {694BF884-B2E4-4E1C-9342-0564BAAC4575}.Debug|x64.ActiveCfg = Debug|Any CPU + {694BF884-B2E4-4E1C-9342-0564BAAC4575}.Debug|x64.Build.0 = Debug|Any CPU + {694BF884-B2E4-4E1C-9342-0564BAAC4575}.Release|Any CPU.ActiveCfg = Release|Any CPU + {694BF884-B2E4-4E1C-9342-0564BAAC4575}.Release|Any CPU.Build.0 = Release|Any CPU + {694BF884-B2E4-4E1C-9342-0564BAAC4575}.Release|x64.ActiveCfg = Release|Any CPU + {694BF884-B2E4-4E1C-9342-0564BAAC4575}.Release|x64.Build.0 = Release|Any CPU + {867FFC34-DFA7-400F-B9BB-85158326CE08}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {867FFC34-DFA7-400F-B9BB-85158326CE08}.Debug|Any CPU.Build.0 = Debug|Any CPU + {867FFC34-DFA7-400F-B9BB-85158326CE08}.Debug|x64.ActiveCfg = Debug|Any CPU + {867FFC34-DFA7-400F-B9BB-85158326CE08}.Debug|x64.Build.0 = Debug|Any CPU + {867FFC34-DFA7-400F-B9BB-85158326CE08}.Release|Any CPU.ActiveCfg = Release|Any CPU + {867FFC34-DFA7-400F-B9BB-85158326CE08}.Release|Any CPU.Build.0 = Release|Any CPU + {867FFC34-DFA7-400F-B9BB-85158326CE08}.Release|x64.ActiveCfg = Release|Any CPU + {867FFC34-DFA7-400F-B9BB-85158326CE08}.Release|x64.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -925,6 +945,8 @@ Global {AB8D68F1-6C3E-41FD-B0EC-A093E009341D} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {763FF013-8309-4680-A769-B54E7BB99612} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {DB2CA055-8ABD-4E3E-8089-5B64C3415E85} = {09EADF06-BE25-4228-AB53-95AE3E15B530} + {694BF884-B2E4-4E1C-9342-0564BAAC4575} = {09EADF06-BE25-4228-AB53-95AE3E15B530} + {867FFC34-DFA7-400F-B9BB-85158326CE08} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj index 6e1ffed0c9..9387bfbabe 100644 --- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj +++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj @@ -1,7 +1,7 @@  - net8.0 + net8.0 false enable preview @@ -19,4 +19,9 @@ + + + + + diff --git a/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs b/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs new file mode 100644 index 0000000000..c5319ffddf --- /dev/null +++ b/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs @@ -0,0 +1,48 @@ +using TorchSharp; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI; +internal class GenAILinear : nn.Module +{ +#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly Tensor weight; + private readonly Tensor? bias; +#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly int _inFeatures; + private readonly int _outFeatures; + + public GenAILinear(int inFeatures, int outFeatures, bool hasBias = true, ScalarType dtype = ScalarType.Float32, string? device = null) + : base(nameof(GenAILinear)) + { + this._inFeatures = inFeatures; + this._outFeatures = outFeatures; + device ??= torch.get_default_device().ToString(); + this.weight = torch.randn(outFeatures, inFeatures, dtype: dtype, device: device); + + if (hasBias) + { + this.bias = torch.randn(outFeatures, dtype: dtype, device: device); + } + + this.RegisterComponents(); + } + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override Tensor forward(Tensor input) +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + using var dispose = torch.NewDisposeScope(); + + // use float32 + var input2 = input.to_type(ScalarType.Float32); + var weight2 = this.weight.to_type(ScalarType.Float32); + var result = torch.matmul(input2, weight2.t()); + + if (this.bias is not null) + { + result = result + this.bias.to_type(ScalarType.Float32); + } + + return result.to_type(input.dtype).MoveToOuterDisposeScope(); + } +} diff --git a/src/Microsoft.ML.GenAI.Core/Module/NewGELUActivation.cs b/src/Microsoft.ML.GenAI.Core/Module/NewGELUActivation.cs new file mode 100644 index 0000000000..a20ad47bd6 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Core/Module/NewGELUActivation.cs @@ -0,0 +1,24 @@ +using System; +using TorchSharp; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI; +#pragma warning disable MSML_GeneralName // This name should be PascalCased +internal class NewGELUActivation : torch.nn.Module +#pragma warning disable MSML_GeneralName // This name should be PascalCased +{ + public NewGELUActivation() + : base(nameof(NewGELUActivation)) + { + } + + public override Tensor forward(Tensor input) + { + using var result = 0.044715 * torch.pow(input, 3.0); + using var result2 = result + input; + using var result3 = Math.Sqrt(2.0 / Math.PI) * result2; + using var result4 = torch.tanh(result3); + using var result5 = 1.0 + result4; + return 0.5 * input * result5; + } +} diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelOutput.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelOutput.cs index 10dde68852..f3ab2c5041 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelOutput.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelOutput.cs @@ -10,7 +10,7 @@ public class CasualLMModelOutput { public CasualLMModelOutput( Tensor lastHiddenState, - Tensor logits, + Tensor? logits = null, Tensor[]? allHiddenStates = null, Tensor[]? attentions = null, IKVCache? cache = null) @@ -22,7 +22,7 @@ public CasualLMModelOutput( this.Cache = cache; } - public Tensor Logits { get; set; } + public Tensor? Logits { get; set; } public Tensor LastHiddenState { get; set; } diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs index fa61f7b43a..bc7d3c8e0d 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs @@ -84,7 +84,7 @@ public virtual ( OverrideCache = cache, }; var output = this.Model.forward(input); - logits = output.Logits; + logits = output.Logits ?? throw new InvalidOperationException("Logits is null"); torch.Tensor nextToken; if (temperature > 0) { diff --git a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj new file mode 100644 index 0000000000..e313da05ef --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj @@ -0,0 +1,22 @@ + + + + net8.0 + enable + enable + + + + + + + + + + + + + + + + diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi2Attention.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi2Attention.cs new file mode 100644 index 0000000000..7bb45bef3f --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi2Attention.cs @@ -0,0 +1,155 @@ +using System.Diagnostics.Contracts; +using TorchSharp; +using TorchSharp.Modules; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Phi.Module; + +internal class Phi2Attention : nn.Module< + Tensor, // hidden_states + Tensor, // position_ids + Tensor?, // attention_mask + int, // past_key_value_length + bool, // output_attentions + ( + Tensor, // hidden_states, + Tensor?, // attentions, + Tensor? // present_key_value + )> +{ + private readonly int? _layerIdx; + private readonly Phi2Config _config; + private readonly double _attentionDropout; + private readonly int _hiddenSize; + private readonly int _numAttentionHeads; + private readonly int _headDim; + private readonly int _numKeyValueHeads; + private readonly int _numKeyValueGroups; + private readonly int _maxPositionEmbeddings; + private readonly double _ropeTheta; + private readonly double _partialRotaryFactor; + private readonly bool _qkLayernorm; + + // we disable the warning for the private field name not in _camelCase format for all submodules fields + // because their name will be used as keys to load the corresponding weights from the checkpoint +#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly GenAILinear q_proj; + private readonly GenAILinear k_proj; + private readonly GenAILinear v_proj; + private readonly GenAILinear dense; + private readonly LayerNorm? q_layernorm; + private readonly LayerNorm? k_layernorm; + + private readonly Phi2RotaryEmbedding phiRotaryEmbedding; + + // cache_k, cache_v + private Tensor cache_k; + private Tensor cache_v; +#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format + + public Phi2Attention(Phi2Config config, int? layerIdx = null, int maxBatch = 2, int maxLength = 1024) + : base(nameof(Phi2Attention)) + { + this._layerIdx = layerIdx; + this._config = config; + this._attentionDropout = config.AttentionDropout; + this._hiddenSize = config.HiddenSize; + this._numAttentionHeads = config.NumAttentionHeads; + this._headDim = this._hiddenSize / this._numAttentionHeads; + this._numKeyValueHeads = config.NumKeyValueHeads ?? throw new ArgumentException("num_key_value_heads must be specified"); + this._numKeyValueGroups = this._numAttentionHeads / this._numKeyValueHeads; + this._maxPositionEmbeddings = config.MaxPositionEmbeddings; + this._ropeTheta = config.RopeTheta; + this._partialRotaryFactor = config.PartialRotaryFactor; + + Contract.Assert(this._hiddenSize % (this._headDim * this._numAttentionHeads) == 0, "hidden_size must be divisible by num_attention_heads"); + this.q_proj = new GenAILinear(this._hiddenSize, this._numAttentionHeads * this._headDim, hasBias: true, dtype: config.Dtype); + this.k_proj = new GenAILinear(this._hiddenSize, this._numKeyValueHeads * this._headDim, hasBias: true, dtype: config.Dtype); + this.v_proj = new GenAILinear(this._hiddenSize, this._numKeyValueHeads * this._headDim, hasBias: true, dtype: config.Dtype); + this.dense = new GenAILinear(this._numAttentionHeads * this._headDim, this._hiddenSize, hasBias: true, dtype: config.Dtype); + + this._qkLayernorm = config.QkLayernorm; + if (this._qkLayernorm) + { + this.q_layernorm = nn.LayerNorm(this._hiddenSize / this._numAttentionHeads, eps: config.LayerNormEps, elementwise_affine: true, dtype: config.Dtype); + this.k_layernorm = nn.LayerNorm(this._hiddenSize / this._numAttentionHeads, eps: config.LayerNormEps, elementwise_affine: true, dtype: config.Dtype); + } + + this.RegisterComponents(); + this.phiRotaryEmbedding = new Phi2RotaryEmbedding( + dim: (int)(this._partialRotaryFactor * this._headDim), + maxPositionEmbeddings: this._maxPositionEmbeddings, + baseValue: this._config.RopeTheta); + this.cache_k = torch.zeros(maxBatch, this._numKeyValueHeads, maxLength, this._headDim, dtype: config.Dtype); + this.cache_v = torch.zeros(maxBatch, this._numKeyValueHeads, maxLength, this._headDim, dtype: config.Dtype); + } + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override (Tensor, Tensor?, Tensor?) forward( +#pragma warning restore MSML_GeneralName // This name should be PascalCased + Tensor hiddenStates, + Tensor positionIds, + Tensor? attentionMask = null, + int pastKeyValueLength = 0, + bool outputAttentions = false) + { + // move cache to the same device as hiddenStates + if (this.cache_k.device != hiddenStates.device) + { + this.cache_k = this.cache_k.to(hiddenStates.device, disposeAfter: true).DetachFromDisposeScope(); + this.cache_v = this.cache_v.to(hiddenStates.device, disposeAfter: true).DetachFromDisposeScope(); + } + + using var disposeScope = torch.NewDisposeScope(); + var batchSize = (int)hiddenStates.shape[0]; + var seqLen = (int)hiddenStates.shape[1]; + + var queryStates = this.q_proj.forward(hiddenStates); + var keyStates = this.k_proj.forward(hiddenStates); + var valueStates = this.v_proj.forward(hiddenStates); + if (this._qkLayernorm) + { + queryStates = this.q_layernorm!.forward(queryStates); + keyStates = this.k_layernorm!.forward(keyStates); + } + + queryStates = queryStates.view(batchSize, seqLen, this._numAttentionHeads, this._headDim).transpose_(1, 2); + keyStates = keyStates.view(batchSize, seqLen, this._numKeyValueHeads, this._headDim).transpose_(1, 2); + valueStates = valueStates.view(batchSize, seqLen, this._numKeyValueHeads, this._headDim).transpose_(1, 2); + var kvSeqLen = pastKeyValueLength == 0 ? (int)keyStates.shape[2] : pastKeyValueLength + (int)keyStates.shape[2]; + (var cos, var sin) = this.phiRotaryEmbedding.forward(valueStates, kvSeqLen); + // split the last dim of queryStates and keyStates into rotary and non-rotary parts + // shape: [batch_size, num_heads, seq_len, head_dim] + // queryRot: [batch_size, num_heads, seq_len, :head_dim * partial_rotary_factor] + // queryPass: [batch_size, num_heads, seq_len, head_dim * partial_rotary_factor:] + var keyRot = keyStates[.., .., .., ..this.phiRotaryEmbedding.Dim]; + var keyPass = keyStates[.., .., .., this.phiRotaryEmbedding.Dim..]; + var queryRot = queryStates[.., .., .., ..this.phiRotaryEmbedding.Dim]; + var queryPass = queryStates[.., .., .., this.phiRotaryEmbedding.Dim..]; + (var qRot, var kRot) = Utils.ApplyRotaryPosEmb(queryRot, keyRot, cos, sin, positionIds); + + queryStates = torch.cat([qRot, queryPass], dim: -1); + // update cache + keyStates = torch.cat([kRot, keyPass], dim: -1); + this.cache_k[..batchSize, .., pastKeyValueLength..kvSeqLen, ..] = keyStates; + this.cache_v[..batchSize, .., pastKeyValueLength..kvSeqLen, ..] = valueStates; + keyStates = this.cache_k[..batchSize, .., ..kvSeqLen, ..]; + valueStates = this.cache_v[..batchSize, .., ..kvSeqLen, ..]; + var keyStates2 = Utils.Phi2RepeatKV(keyStates, this._numKeyValueGroups).transpose(2, 3); + var valueStates2 = Utils.Phi2RepeatKV(valueStates, this._numKeyValueGroups); + // Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow + var attnWeights = torch.matmul(queryStates.to_type(float32), keyStates2.to_type(float32)); + attnWeights = attnWeights / Math.Sqrt(this._headDim); + if (attentionMask is not null) + { + attnWeights = attnWeights + attentionMask; + } + attnWeights = nn.functional.softmax(attnWeights, dim: -1); + attnWeights = nn.functional.dropout(attnWeights, p: this._attentionDropout); + var attnOutput = torch.matmul(attnWeights, valueStates2.to_type(float32)).to_type(hiddenStates.dtype); + attnOutput = attnOutput.transpose_(1, 2).contiguous(); + attnOutput = attnOutput.reshape(batchSize, seqLen, this._hiddenSize); + var result = this.dense.forward(attnOutput); + return (result.MoveToOuterDisposeScope(), null, null); + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi2DecoderLayer.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi2DecoderLayer.cs new file mode 100644 index 0000000000..f1f87ee079 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi2DecoderLayer.cs @@ -0,0 +1,62 @@ +using TorchSharp; +using TorchSharp.Modules; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Phi.Module; +public class Phi2DecoderLayer : nn.Module< + Tensor, // hidden_states + Tensor, // position_ids + Tensor?, // attention_mask + int, // past_key_value_length + bool, // use_cache + bool, // output_attentions + ( + Tensor, // hidden_states, + Tensor?, // attentions, + Tensor? // present_key_value + )> +{ + private readonly int? _layerIdx; + +#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly Phi2Attention self_attn; + private readonly Phi2MLP mlp; + private readonly LayerNorm input_layernorm; + private readonly Dropout resid_dropout; +#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format + + public Phi2DecoderLayer(Phi2Config config, int? layerIdx = null) + : base(nameof(Phi2DecoderLayer)) + { + this._layerIdx = layerIdx; + this.self_attn = new Phi2Attention(config, layerIdx); + this.mlp = new Phi2MLP(config); + this.input_layernorm = nn.LayerNorm(config.HiddenSize, eps: config.LayerNormEps, dtype: config.Dtype); + this.resid_dropout = nn.Dropout(config.ResidPdrop); + } + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override (Tensor, Tensor?, Tensor?) forward( +#pragma warning restore MSML_GeneralName // This name should be PascalCased + Tensor hiddenStates, + Tensor positionIds, + Tensor? attentionMask = null, + int pastKeyValueLength = 0, + bool useCache = false, + bool outputAttentions = false) + { + using var disposiableScope = torch.NewDisposeScope(); + var residual = hiddenStates; + hiddenStates = this.input_layernorm.forward(hiddenStates); + (var attnOutput, var attnWeights, var presentKeyValue) = this.self_attn.forward( + hiddenStates: hiddenStates, + positionIds: positionIds, + attentionMask: attentionMask, + pastKeyValueLength: pastKeyValueLength, + outputAttentions: outputAttentions); + var feedForwardHiddenStates = this.mlp.forward(hiddenStates); + hiddenStates = residual + feedForwardHiddenStates + attnOutput; + + return (hiddenStates.MoveToOuterDisposeScope(), null, null); + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi2MLP.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi2MLP.cs new file mode 100644 index 0000000000..8d16bbb152 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi2MLP.cs @@ -0,0 +1,33 @@ +using TorchSharp; +using TorchSharp.Modules; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Phi.Module; + +#pragma warning disable MSML_GeneralName // This name should be PascalCased +internal class Phi2MLP : torch.nn.Module +#pragma warning restore MSML_GeneralName // This name should be PascalCased +{ +#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly GenAILinear fc1; + private readonly GenAILinear fc2; + private readonly torch.nn.Module activation_fn; +#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format + + public Phi2MLP(Phi2Config config) + : base(nameof(Phi2MLP)) + { + this.fc1 = new GenAILinear(config.HiddenSize, config.IntermediateSize, dtype: config.Dtype); + this.fc2 = new GenAILinear(config.IntermediateSize, config.HiddenSize, dtype: config.Dtype); + this.activation_fn = new NewGELUActivation(); + } + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override Tensor forward(Tensor input) +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + using var input1 = this.fc1.forward(input); + using var input2 = this.activation_fn.forward(input1); + return this.fc2.forward(input2); + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi2Model.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi2Model.cs new file mode 100644 index 0000000000..2d7ee2d997 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi2Model.cs @@ -0,0 +1,154 @@ +using System.Diagnostics.Contracts; +using TorchSharp; +using TorchSharp.Modules; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Phi.Module; + +internal class Phi2Model : nn.Module< + Tensor, // input_ids + Tensor?, // attention_mask + int, // past_key_value_length + Tensor?, // position_ids + Tensor?, //input embeddings + ( + bool, // use_cache + bool, // output_attentions + bool // output_hidden_states + ), + ( + Tensor, // hidden_states, + Tensor?, // attentions, + Tensor? // present_key_value + )> +{ + private readonly Phi2Config _config; +#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly Embedding embed_tokens; + private readonly Dropout embed_dropout; + private readonly LayerNorm final_layernorm; + private readonly ModuleList layers; +#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format + + public Phi2Model(Phi2Config config) + : base(nameof(Phi2Model)) + { + this._config = config; + this.embed_tokens = nn.Embedding(config.VocabSize, config.HiddenSize, dtype: config.Dtype); + this.embed_dropout = nn.Dropout(config.EmbdPdrop); + this.final_layernorm = nn.LayerNorm(config.HiddenSize, eps: config.LayerNormEps, dtype: config.Dtype); + this.layers = new ModuleList(Enumerable.Range(0, config.NumHiddenLayers).Select(i => new Phi2DecoderLayer(config)).ToArray()); + this.RegisterComponents(); + } + + public Phi2Config Config => this._config; + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override (Tensor, Tensor?, Tensor?) forward( +#pragma warning restore MSML_GeneralName // This name should be PascalCased + Tensor inputIds, + Tensor? attentionMask = null, + int pastKeyValueLength = 0, + Tensor? positionIds = null, + Tensor? inputEmbeddings = null, + (bool, bool, bool) options = default) // use_cache, output_attentions, output_hidden_states + { + (var outputAttentions, var outputHiddenStates, var useCache) = options; + + // TODO + // add support for inputEmbeddings + if (inputEmbeddings is not null) + { + throw new NotImplementedException("inputEmbeddings is not supported"); + } + inputEmbeddings = this.embed_tokens.forward(inputIds); + inputEmbeddings = this.embed_dropout.forward(inputEmbeddings); + var batchSize = inputIds.shape[0]; + var seqLen = (int)inputIds.shape[1]; + + if (positionIds is null) + { + positionIds = torch.arange(pastKeyValueLength, seqLen + pastKeyValueLength, dtype: inputIds.dtype, device: inputIds.device); + positionIds = positionIds.unsqueeze(0); + } + + // attention + // use 4d attention mask + if (attentionMask is not null) + { + attentionMask = this.Prepare4DCasualAttentionMask(attentionMask, seqLen, pastKeyValueLength, inputEmbeddings.dtype); + } + + var hiddenStates = inputEmbeddings; + + for (int i = 0; i < this.layers.Count; i++) + { + (hiddenStates, _, _) = this.layers[i].forward( + hiddenStates: hiddenStates, + positionIds: positionIds, + attentionMask: attentionMask, + pastKeyValueLength: pastKeyValueLength, + useCache: useCache, + outputAttentions: outputAttentions); + } + + hiddenStates = this.final_layernorm.forward(hiddenStates); + return (hiddenStates, null, null); + } + + private Tensor Prepare4DCasualAttentionMask( + Tensor attentionMask, + int queryLength, + int pastKeyValueLength, + ScalarType dtype) + { + var batchSize = (int)attentionMask.shape[0]; + var seqLen = attentionMask.shape[1]; + Contract.Assert(seqLen == queryLength, "seqLen must be equal to queryLength"); + var targetLength = queryLength + pastKeyValueLength; + var casual4DMask = this.MakeCasualAttentionMask(batchSize, queryLength, pastKeyValueLength, attentionMask.device, dtype); + var expandedMask = this.ExpandMask(attentionMask, dtype, queryLength).to(attentionMask.device); + + casual4DMask.masked_fill_(expandedMask.to_type(ScalarType.Bool), torch.finfo(dtype).min); + return casual4DMask; + } + + private Tensor ExpandMask( + Tensor mask, + ScalarType dtype, + int targetLength) + { + var batch = mask.shape[0]; + var seqLen = mask.shape[1]; + var expandedMask = mask.unsqueeze(1).unsqueeze(2); + expandedMask = expandedMask.expand(new long[] { batch, 1, targetLength, seqLen }); + expandedMask = expandedMask.to_type(dtype); + + var invertedMask = (1.0f - expandedMask) > 0; + + return invertedMask.masked_fill(invertedMask.to_type(ScalarType.Bool), torch.finfo(dtype).min); + } + private Tensor MakeCasualAttentionMask( + int batchSize, + int targetLen, + int pastKeyValueLength, + Device device, + ScalarType dtype) + { + var mask = torch.full([targetLen, targetLen], torch.finfo(dtype).min, dtype: dtype, device: device); + var maskCond = torch.arange(mask.size(-1), device: device); + mask.masked_fill_(maskCond < (maskCond + 1).view(mask.size(-1), 1), 0.0f); + + mask = mask.to_type(dtype); + + if (pastKeyValueLength > 0) + { + mask = torch.cat([torch.zeros([targetLen, pastKeyValueLength], dtype: dtype, device: device), mask], dim: -1); + } + + mask = mask.unsqueeze(0).unsqueeze(0); + mask = mask.expand(new long[] { batchSize, 1, targetLen, targetLen + pastKeyValueLength }); + + return mask; + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi2RotaryEmbedding.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi2RotaryEmbedding.cs new file mode 100644 index 0000000000..ab14c9bb6a --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi2RotaryEmbedding.cs @@ -0,0 +1,45 @@ +using TorchSharp; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Phi.Module; +internal class Phi2RotaryEmbedding : nn.Module< + Tensor, // input + int, // seq_len + ( + Tensor, // cos + Tensor // sin + )> +{ + private readonly double _base; + private readonly int _maxPositionEmbeddings; + private readonly int _dim; + + public Phi2RotaryEmbedding(double baseValue, int maxPositionEmbeddings, int dim) + : base(nameof(Phi2RotaryEmbedding)) + { + _base = baseValue; + _maxPositionEmbeddings = maxPositionEmbeddings; + _dim = dim; + var thetaNumerator = torch.arange(0, _dim, 2, dtype: ScalarType.Int64).to(torch.float32); + this.register_buffer("inv_freq", torch.pow(baseValue, -1.0f * (thetaNumerator / dim)), persistent: false); + } + + public int Dim => _dim; + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override (Tensor, Tensor) forward(Tensor x, int seqLen) +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + // TODO + // can be calculated once and cached + var invFreq = this.get_buffer("inv_freq").to(x.device); + var t = torch.arange(seqLen, dtype: invFreq.dtype, device: invFreq.device); + var freqs = torch.outer(t, invFreq).to(torch.float32); + var emb = torch.cat([freqs, freqs], dim: -1); + + var cos = torch.cos(emb); + var sin = torch.sin(emb); + + return (cos[..seqLen].to_type(x.dtype), sin[..seqLen].to_type(x.dtype)); + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3Attention.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3Attention.cs new file mode 100644 index 0000000000..d1ffd970d5 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3Attention.cs @@ -0,0 +1,192 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics.Contracts; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.ML.GenAI.Core; +using TorchSharp; +using TorchSharp.Modules; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Phi.Module; + +internal class Phi3AttentionInput +{ + public Phi3AttentionInput( + Tensor hiddenStates, + Tensor positionIds, + Tensor? attentionMask = null, + IKVCache? cache = null, + bool outputAttentions = false) + { + this.HiddenStates = hiddenStates; + this.AttentionMask = attentionMask; + this.PositionIds = positionIds; + this.Cache = cache; + this.OutputAttentions = outputAttentions; + } + public Tensor HiddenStates { get; set; } + + public Tensor? AttentionMask { get; set; } + + public Tensor PositionIds { get; set; } + + public IKVCache? Cache { get; set; } + + public bool OutputAttentions { get; set; } +} + +internal class Phi3AttentionOutput +{ + public Phi3AttentionOutput( + Tensor hiddenStates, + Tensor? attentions = null, + IKVCache? cache = null) + { + this.HiddenStates = hiddenStates; + this.Attentions = attentions; + this.Cache = cache; + } + + public Tensor HiddenStates { get; set; } + + public Tensor? Attentions { get; set; } + + public IKVCache? Cache { get; set; } +} + +internal class Phi3Attention : nn.Module +{ + private readonly Phi3Config _config; + private readonly int _layerIdx; + private readonly double _attentionDropout; + private readonly int _hiddenSize; + private readonly int _numHeads; + private readonly int _headDim; + private readonly int _numKeyValueHeads; + private readonly int _numKeyValueGroups; + private readonly int _maxPositionEmbeddings; + private readonly int _originalMaxPositionEmbeddings; + private readonly double _ropeTheta; + private readonly Dictionary? _ropeScaling; +#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly GenAILinear o_proj; + private readonly GenAILinear qkv_proj; + private nn.Module rotary_emb = null!; +#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format + + public Phi3Attention(Phi3Config config, int layerIdx) + : base(nameof(Phi3Attention)) + { + this._config = config; + this._layerIdx = layerIdx; + this._attentionDropout = config.AttentionDropout; + this._hiddenSize = config.HiddenSize; + this._numHeads = config.NumAttentionHeads; + this._headDim = this._hiddenSize / this._numHeads; + this._numKeyValueHeads = config.NumKeyValueHeads ?? throw new ArgumentException("num_key_value_heads must be specified"); + this._numKeyValueGroups = this._numHeads / this._numKeyValueHeads; + this._maxPositionEmbeddings = config.MaxPositionEmbeddings; + this._originalMaxPositionEmbeddings = config.OriginalMaxPositionEmbeddings; + this._ropeTheta = config.RopeTheta; + this._ropeScaling = config.RopeScaling; + + Contract.Assert(this._hiddenSize % (this._headDim * this._numHeads) == 0, "hidden_size must be divisible by num_heads"); + + var opSize = this._numHeads * this._headDim + 2 * (this._numKeyValueHeads * this._headDim); + this.o_proj = new GenAILinear(this._numHeads * this._headDim, this._hiddenSize, hasBias: false, dtype: config.DType); + this.qkv_proj = new GenAILinear(this._hiddenSize, opSize, hasBias: false, dtype: config.DType); + this.InitRope(); + } + + private void InitRope() + { + if (this._ropeScaling is null) + { + this.rotary_emb = new Phi3RotaryEmbedding(this._ropeTheta, this._maxPositionEmbeddings, this._headDim); + } + else + { + this.rotary_emb = new Phi3SuScaledRotaryEmbedding(this._headDim, this._config); + } + } + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override Phi3AttentionOutput forward(Phi3AttentionInput input) +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + using (var _ = NewDisposeScope()) + { + var hiddenStates = input.HiddenStates; + var positionIds = input.PositionIds; + var outputAttentions = input.OutputAttentions; + var bsz = hiddenStates.shape[0]; + var qLen = hiddenStates.shape[1]; + + var qkv = this.qkv_proj.forward(hiddenStates); + var queryPos = this._numHeads * this._headDim; + var queryStates = qkv[.., .., ..queryPos]; + var keyStates = qkv[.., .., queryPos..(queryPos + this._numKeyValueHeads * this._headDim)]; + var valueStates = qkv[.., .., (queryPos + this._numKeyValueHeads * this._headDim)..]; + queryStates = queryStates.view(bsz, qLen, this._numHeads, this._headDim).transpose(1, 2); + keyStates = keyStates.view(bsz, qLen, this._numKeyValueHeads, this._headDim).transpose(1, 2); + valueStates = valueStates.view(bsz, qLen, this._numKeyValueHeads, this._headDim).transpose(1, 2); + + var kvSeqLen = keyStates.IntShape()[^2]; + var pastKeyValue = input.Cache; + if (pastKeyValue is not null) + { + kvSeqLen += pastKeyValue.GetUsableLength(kvSeqLen, this._layerIdx); + } + + var embOutput = this.rotary_emb.forward(new Phi3RotaryEmbeddingInput(valueStates, positionIds, kvSeqLen)); + (var cos, var sin) = (embOutput.Cos, embOutput.Sin); + + (queryStates, keyStates) = Utils.ApplyRotaryPosEmb(queryStates, keyStates, cos, sin); + + if (pastKeyValue is not null) + { + (keyStates, valueStates) = pastKeyValue.UpdateKVCache(keyStates, valueStates, this._layerIdx); + } + + // repeat k/v heads if n_kv_heads < n_heads + keyStates = Utils.Phi3RepeatKV(keyStates, this._numKeyValueGroups); + valueStates = Utils.Phi3RepeatKV(valueStates, this._numKeyValueGroups); + + var attnWeights = torch.matmul(queryStates, keyStates.transpose(2, 3)); + attnWeights = attnWeights / Math.Sqrt(this._headDim); + + // attnWeight's shape should be [bsz, this._numHeads, qLen, kvSeqLen] + Contract.Assert(attnWeights.shape.Length == 4); + Contract.Assert(attnWeights.shape[0] == bsz); + Contract.Assert(attnWeights.shape[1] == this._numHeads); + Contract.Assert(attnWeights.shape[2] == qLen); + Contract.Assert(attnWeights.shape[3] == kvSeqLen); + + var attentionMask = input.AttentionMask; + if (attentionMask is not null) + { + Contract.Assert(attentionMask.shape.Length == 4); + Contract.Assert(attentionMask.shape[0] == bsz); + Contract.Assert(attentionMask.shape[1] == 1); + Contract.Assert(attentionMask.shape[2] == qLen); + Contract.Assert(attentionMask.shape[3] == kvSeqLen); + attnWeights = attnWeights + attentionMask; + } + + // upscale attention to fp32 to avoid overflow + attnWeights = nn.functional.softmax(attnWeights, dim: -1, dtype: ScalarType.Float32).to(valueStates.dtype); + attnWeights = nn.functional.dropout(attnWeights, this._attentionDropout, this.training); + + var attnOutput = torch.matmul(attnWeights, valueStates); + + attnOutput = attnOutput.transpose(1, 2).contiguous(); + attnOutput = attnOutput.reshape(bsz, qLen, this._hiddenSize); + + attnOutput = this.o_proj.forward(attnOutput); + + return new(attnOutput.MoveToOuterDisposeScope(), outputAttentions ? attnWeights.MoveToOuterDisposeScope() : null, pastKeyValue); + } + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3DecoderLayer.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3DecoderLayer.cs new file mode 100644 index 0000000000..55a4700db9 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3DecoderLayer.cs @@ -0,0 +1,125 @@ +using Microsoft.ML.GenAI.Core; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using TorchSharp.Modules; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Phi.Module; + +internal class Phi3DecoderLayerInput +{ + public Phi3DecoderLayerInput( + Tensor hiddenStates, + Tensor attentionMask, + Tensor positionIds, + IKVCache? pastKeyValue = null, + bool outputAttentions = false) + { + this.HiddenStates = hiddenStates; + this.AttentionMask = attentionMask; + this.PositionIds = positionIds; + this.PastKeyValue = pastKeyValue; + this.OutputAttentions = outputAttentions; + } + + public Tensor HiddenStates { get; set; } + + public Tensor AttentionMask { get; set; } + + public Tensor PositionIds { get; set; } + + public IKVCache? PastKeyValue { get; set; } + + public bool OutputAttentions { get; set; } +} + +internal class Phi3DecoderLayerOutput +{ + public Phi3DecoderLayerOutput( + Tensor hiddenStates, + Tensor? attentions = null, + IKVCache? pastKeyValue = null) + { + this.HiddenStates = hiddenStates; + this.Attentions = attentions; + this.PastKeyValue = pastKeyValue; + } + + public Tensor HiddenStates { get; set; } + + public Tensor? Attentions { get; set; } + + public IKVCache? PastKeyValue { get; set; } +} + +internal class Phi3DecoderLayer : nn.Module, IDynamicLoadModule +{ + private readonly Phi3Config _config; +#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly nn.Module self_attn; + private readonly Phi3MLP mlp; + private readonly Phi3RMSNorm input_layernorm; + private readonly Dropout resid_attn_dropout; + private readonly Dropout resid_mlp_dropout; + private readonly Phi3RMSNorm post_attention_layernorm; +#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format + + public Phi3DecoderLayer(Phi3Config config, int layerIdx) + : base(nameof(Phi3DecoderLayer)) + { + this._config = config; + if (config.AttnImplementation == "eager") + { + this.self_attn = new Phi3Attention(config, layerIdx); + } + else + { + throw new NotImplementedException(); + } + + this.mlp = new Phi3MLP(config); + this.input_layernorm = new Phi3RMSNorm(config.HiddenSize, config.RmsNormEps, config.DType); + + this.resid_attn_dropout = nn.Dropout(config.ResidPdrop); + this.resid_mlp_dropout = nn.Dropout(config.ResidPdrop); + this.post_attention_layernorm = new Phi3RMSNorm(config.HiddenSize, config.RmsNormEps, config.DType); + } + + public Action? LoadToDeviceFunc { get; set; } + + public Action? UnloadFromDeviceFunc { get; set; } + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override Phi3DecoderLayerOutput forward(Phi3DecoderLayerInput input) +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + if (LoadToDeviceFunc != null) + { + LoadToDeviceFunc(this); + } + using var disposeScope = NewDisposeScope(); + var hiddenStates = input.HiddenStates; + var residual = input.HiddenStates; + hiddenStates = this.input_layernorm.forward(hiddenStates); + + var attentionInput = new Phi3AttentionInput(hiddenStates, input.PositionIds, input.AttentionMask, input.PastKeyValue, input.OutputAttentions); + var output = this.self_attn.forward(attentionInput); + var attnOutputs = output.HiddenStates; + var selfAttnWeights = output.Attentions; + var presentKeyValue = output.Cache; + hiddenStates = residual + this.resid_attn_dropout.forward(attnOutputs); + residual = hiddenStates; + hiddenStates = this.post_attention_layernorm.forward(hiddenStates); + hiddenStates = this.mlp.forward(hiddenStates); + hiddenStates = residual + this.resid_mlp_dropout.forward(hiddenStates); + + if (UnloadFromDeviceFunc != null) + { + UnloadFromDeviceFunc(this); + } + return new Phi3DecoderLayerOutput(hiddenStates.MoveToOuterDisposeScope(), selfAttnWeights?.MoveToOuterDisposeScope(), presentKeyValue); + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs new file mode 100644 index 0000000000..abec0d78cf --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs @@ -0,0 +1,46 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using static TorchSharp.torch; +using TorchSharp.Modules; +using TorchSharp; + +namespace Microsoft.ML.GenAI.Phi.Module; +#pragma warning disable MSML_GeneralName // This name should be PascalCased +internal class Phi3MLP : torch.nn.Module +#pragma warning restore MSML_GeneralName // This name should be PascalCased +{ +#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly GenAILinear gate_up_proj; + private readonly GenAILinear down_proj; + private readonly torch.nn.Module activation_fn; +#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format + + public Phi3MLP(Phi3Config config) + : this(config.HiddenSize, config.IntermediateSize, config.HiddenAct, config.DType) + { + } + + public Phi3MLP(int hiddenSize, int intermediateSize, string hiddenAct, ScalarType dtype) + : base(nameof(Phi3MLP)) + { + this.gate_up_proj = new GenAILinear(hiddenSize, 2 * intermediateSize, hasBias: false, dtype: dtype); + this.down_proj = new GenAILinear(intermediateSize, hiddenSize, hasBias: false, dtype: dtype); + this.RegisterComponents(); + this.activation_fn = Utils.GetActivation(hiddenAct); + } + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override Tensor forward(Tensor input) +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + using var input1 = this.gate_up_proj.forward(input); + var chunks = input1.chunk(2, dim: -1); + var gate = chunks[0]; + var upStatus = chunks[1]; + upStatus = upStatus * this.activation_fn.forward(gate); + return this.down_proj.forward(upStatus); + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3Model.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3Model.cs new file mode 100644 index 0000000000..9f9f0a17ab --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3Model.cs @@ -0,0 +1,134 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.GenAI.Core; +using TorchSharp; +using TorchSharp.Modules; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Phi.Module; + +internal class Phi3Model : nn.Module +{ + private readonly Phi3Config _config; + private readonly int _paddingIdx; + private readonly int _vocabSize; + private IKVCache _cache; +#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly Embedding embed_tokens; + private readonly Dropout embed_dropout; + private readonly ModuleList layers; + private readonly Phi3RMSNorm norm; +#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format + + public Phi3Model(Phi3Config config) + : base(nameof(Phi3Model)) + { + this._config = config; + this._paddingIdx = config.PadTokenId ?? 32000; + this._vocabSize = config.VocabSize; + + this.embed_tokens = nn.Embedding(config.VocabSize, config.HiddenSize, padding_idx: this._paddingIdx, dtype: config.DType); + this.embed_dropout = nn.Dropout(config.EmbdPdrop); + this.layers = new ModuleList(); + + for (int i = 0; i < config.NumHiddenLayers; i++) + { + this.layers.Add(new Phi3DecoderLayer(config, i)); + } + this.norm = new Phi3RMSNorm(config.HiddenSize, config.RmsNormEps, config.DType); + this._cache = new DynamicKVCache(); + this.RegisterComponents(); + } +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override CasualLMModelOutput forward(CasualLMModelInput input) +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + if (input.OverrideCache is not null) + { + this._cache = input.OverrideCache; + } + + var outputAttentions = input.OutputAttentions; + var outputHiddenStates = input.OutputHiddenStates; + var attentionMask = input.AttentionMask; + Device device; + var inputIds = input.InputIds; + var positionIds = input.PositionIds; + var inputsEmbeds = input.InputEmbeddings; + int batchSize; + int seqLength; + if (inputIds is not null && inputsEmbeds is not null) + { + throw new ArgumentException("Only one of input_ids or inputs_embeds may be set"); + } + else if (inputIds is not null) + { + batchSize = inputIds.IntShape()[0]; + seqLength = inputIds.IntShape()[1]; + inputsEmbeds = this.embed_tokens.forward(inputIds); + device = inputIds.device; + } + else if (inputsEmbeds is not null) + { + batchSize = inputsEmbeds.IntShape()[0]; + seqLength = inputsEmbeds.IntShape()[1]; + device = inputsEmbeds.device; + } + else + { + throw new ArgumentException("Either input_ids or inputs_embeds must be set"); + } + + var pastKeyValuesLength = input.PastKeyValuesLength; + + if (positionIds is null) + { + positionIds = torch.arange(pastKeyValuesLength, seqLength + pastKeyValuesLength, device: device); + positionIds = positionIds.unsqueeze(0).view(-1, seqLength); + } + else + { + positionIds = ((long)positionIds.view(-1, seqLength)); + } + + if (this._config.AttnImplementation == "flash_attention_2") + { + throw new NotImplementedException(); + } + else + { + attentionMask = AttentionMaskConverter.Create4DCausalAttentionMask(attentionMask, [batchSize, seqLength], inputsEmbeds.dtype, device, pastKeyValuesLength, this._config.SlidingWindow); + } + + var hiddenStates = inputsEmbeds; + + var allHiddenStates = new List(); + var allAttentions = new List(); + + foreach (var layer in this.layers) + { + if (outputHiddenStates) + { + allHiddenStates.Add(hiddenStates); + } + + var decoderInput = new Phi3DecoderLayerInput(hiddenStates, attentionMask!, positionIds, this._cache, outputAttentions); + var layerOutput = layer.forward(decoderInput); + hiddenStates = layerOutput.HiddenStates; + if (outputAttentions && layerOutput.Attentions is not null) + { + allAttentions.Add(layerOutput.Attentions); + } + } + + hiddenStates = this.norm.forward(hiddenStates); + if (outputHiddenStates) + { + allHiddenStates.Add(hiddenStates); + } + + return new CasualLMModelOutput(lastHiddenState: hiddenStates, allHiddenStates: allHiddenStates.ToArray(), attentions: allAttentions.ToArray(), cache: this._cache); + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3RMSNorm.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3RMSNorm.cs new file mode 100644 index 0000000000..23cfab24ba --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3RMSNorm.cs @@ -0,0 +1,55 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using static TorchSharp.torch; +using TorchSharp; +using TorchSharp.Modules; + +namespace Microsoft.ML.GenAI.Phi.Module; + +#pragma warning disable MSML_GeneralName // This name should be PascalCased +internal class Phi3RMSNorm : torch.nn.Module +#pragma warning restore MSML_GeneralName // This name should be PascalCased +{ + private readonly int _dim; + private readonly float _eps; +#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly Parameter weight; +#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format + + public Phi3RMSNorm( + int hiddenSize, + float eps = 1e-6f, + ScalarType dtype = ScalarType.Float32) + : base(nameof(Phi3RMSNorm)) + { + this._dim = hiddenSize; + this._eps = eps; + + // the gamma scalar + this.weight = torch.nn.Parameter(torch.ones(this._dim, dtype: dtype)); + } + + private Tensor Norm(Tensor x) + { + // (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim) + // rsqrt = 1 / sqrt + var output = x * torch.rsqrt(x.pow(2).mean([-1L], keepdim: true) + this._eps); + return output; + } + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override Tensor forward(Tensor input) +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + // needs higher precision for the norm so convert to float32 + // (B, Seq_Len, Dim) + var normed = this.Norm(input.to_type(ScalarType.Float32)).type_as(input); + // (B, Seq_Len, Dim) * (Dim) = (B, Seq_Len, Dim) + var output = this.weight * normed; + + return output; + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3RotaryEmbedding.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3RotaryEmbedding.cs new file mode 100644 index 0000000000..226d9b8d14 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3RotaryEmbedding.cs @@ -0,0 +1,77 @@ +using TorchSharp; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Phi.Module; +internal class Phi3RotaryEmbeddingInput +{ + public Phi3RotaryEmbeddingInput(Tensor input, Tensor positionIds, int? seqLen = null) + { + Input = input; + PositionIds = positionIds; + SeqLen = seqLen; + } + + public Tensor Input { get; set; } + + public Tensor PositionIds { get; set; } + + public int? SeqLen { get; set; } +} + +internal class Phi3RotaryEmbeddingOutput +{ + public Phi3RotaryEmbeddingOutput(Tensor cos, Tensor sin) + { + Cos = cos; + Sin = sin; + } + + public Tensor Cos { get; set; } + + public Tensor Sin { get; set; } +} + + +internal class Phi3RotaryEmbedding : nn.Module< + Phi3RotaryEmbeddingInput, + Phi3RotaryEmbeddingOutput> +{ + private readonly double _base; + private readonly int _maxPositionEmbeddings; + private readonly int _dim; + + public Phi3RotaryEmbedding(double baseValue, int maxPositionEmbeddings, int dim) + : base(nameof(Phi3RotaryEmbedding)) + { + _base = baseValue; + _maxPositionEmbeddings = maxPositionEmbeddings; + _dim = dim; + var thetaNumerator = torch.arange(0, _dim, 2, dtype: ScalarType.Int64).to(torch.float32); + this.register_buffer("inv_freq", torch.pow(baseValue, -1.0f * (thetaNumerator / dim)), persistent: false); + } + + public int Dim => _dim; + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override Phi3RotaryEmbeddingOutput forward(Phi3RotaryEmbeddingInput input) +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + var x = input.Input; + var positionIds = input.PositionIds; + var seqLen = input.SeqLen; + // TODO + // can be calculated once and cached + var invFreq = this.get_buffer("inv_freq").to(x.device); + var invFreqExpanded = invFreq.unsqueeze(0).unsqueeze(-1); + invFreqExpanded = invFreqExpanded.expand(new long[] { positionIds.shape[0], -1, 1 }); + var positionIdsExpanded = positionIds.unsqueeze(1).to(torch.float32); + var freqs = invFreqExpanded * positionIdsExpanded; + freqs = freqs.transpose(1, 2); + var emb = torch.cat([freqs, freqs], dim: -1); + + var cos = torch.cos(emb); + var sin = torch.sin(emb); + + return new(cos.to_type(x.dtype), sin.to_type(x.dtype)); + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3SuScaledRotaryEmbedding.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3SuScaledRotaryEmbedding.cs new file mode 100644 index 0000000000..7084c15839 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3SuScaledRotaryEmbedding.cs @@ -0,0 +1,75 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using TorchSharp; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Phi.Module; + +internal class Phi3SuScaledRotaryEmbedding : Phi3RotaryEmbedding +{ + private readonly double[] _shortFactor; + private readonly double[] _longFactor; + private readonly int _originalMaxPositionEmbeddings; + private readonly int _maxPositionEmbeddings; + private readonly double _base; + + public Phi3SuScaledRotaryEmbedding(int dim, Phi3Config config) + : base(config.RopeTheta, config.MaxPositionEmbeddings, dim) + { + JsonElement shortFactorElement = (JsonElement)config.RopeScaling!["short_factor"]; + JsonElement longFactorDocument = (JsonElement)config.RopeScaling!["long_factor"]; + this._shortFactor = shortFactorElement.EnumerateArray().Select(e => e.GetDouble()).ToArray(); + this._longFactor = longFactorDocument.EnumerateArray().Select(e => e.GetDouble()).ToArray(); + + this._originalMaxPositionEmbeddings = config.OriginalMaxPositionEmbeddings; + this._maxPositionEmbeddings = config.MaxPositionEmbeddings; + this._base = config.RopeTheta; + } + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override Phi3RotaryEmbeddingOutput forward(Phi3RotaryEmbeddingInput input) +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + var seqLen = (torch.max(input.PositionIds) + 1).ToInt32(); + var x = input.Input; + Tensor extFactors; + if (seqLen > this._originalMaxPositionEmbeddings) + { + extFactors = torch.tensor(this._longFactor, dtype: ScalarType.Float32, x.device); + } + else + { + extFactors = torch.tensor(this._shortFactor, dtype: ScalarType.Float32, x.device); + } + var invFreqShape = torch.arange(0, this.Dim, 2, dtype: ScalarType.Int64).to(torch.float32) / this.Dim; + invFreqShape = invFreqShape.to(x.device); + var invFreq = 1.0f / (torch.pow(this._base, invFreqShape) * extFactors); + + var invFreqExpanded = invFreq.unsqueeze(0).unsqueeze(-1); + invFreqExpanded = invFreqExpanded.expand(new long[] { input.PositionIds.shape[0], -1, 1 }); + var positionIdsExpanded = input.PositionIds.unsqueeze(1).to(torch.float32); + + var freqs = invFreqExpanded * positionIdsExpanded; + freqs = freqs.transpose(1, 2); + var emb = torch.cat([freqs, freqs], dim: -1); + var scale = (1.0 * this._maxPositionEmbeddings) / this._originalMaxPositionEmbeddings; + double scalingFactor; + if (scale <= 1) + { + scalingFactor = 1.0; + } + else + { + scalingFactor = Math.Sqrt(1 + Math.Log(scale) / Math.Log(this._originalMaxPositionEmbeddings)); + } + + var cos = torch.cos(emb) * scalingFactor; + var sin = torch.sin(emb) * scalingFactor; + + return new(cos.to_type(x.dtype), sin.to_type(x.dtype)); + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs b/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs new file mode 100644 index 0000000000..2727321f6b --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs @@ -0,0 +1,101 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Text.Json.Serialization; +using System.Threading.Tasks; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Phi; + +public class Phi2Config +{ + public Phi2Config() + { + this.VocabSize = 51200; + this.HiddenSize = 2048; + this.IntermediateSize = 8192; + this.NumHiddenLayers = 24; + this.NumAttentionHeads = 32; + this.ResidPdrop = 0.0; + this.EmbdPdrop = 0.0; + this.AttentionDropout = 0.0; + this.HiddenAct = "gelu_new"; + this.MaxPositionEmbeddings = 2048; + this.InitializerRange = 0.02; + this.LayerNormEps = 1e-5; + this.UseCache = true; + this.TieWordEmbeddings = false; + this.RopeTheta = 10000.0; + this.PartialRotaryFactor = 0.5; + this.QkLayernorm = false; + this.BosTokenId = 1; + this.EosTokenId = 2; + this.Dtype = ScalarType.Float32; + } + + [JsonPropertyName("vocab_size")] + public int VocabSize { get; set; } + + [JsonPropertyName("hidden_size")] + public int HiddenSize { get; set; } + + [JsonPropertyName("intermediate_size")] + public int IntermediateSize { get; set; } + + [JsonPropertyName("num_hidden_layers")] + public int NumHiddenLayers { get; set; } + + [JsonPropertyName("num_attention_heads")] + public int NumAttentionHeads { get; set; } + + [JsonPropertyName("num_key_value_heads")] + public int? NumKeyValueHeads { get; set; } + + [JsonPropertyName("resid_pdrop")] + public double ResidPdrop { get; set; } + + [JsonPropertyName("embd_pdrop")] + public double EmbdPdrop { get; set; } + + [JsonPropertyName("attention_dropout")] + public double AttentionDropout { get; set; } + + [JsonPropertyName("hidden_act")] + public string HiddenAct { get; set; } + + [JsonPropertyName("max_position_embeddings")] + public int MaxPositionEmbeddings { get; set; } + + [JsonPropertyName("initializer_range")] + public double InitializerRange { get; set; } + + [JsonPropertyName("layer_norm_eps")] + public double LayerNormEps { get; set; } + + [JsonPropertyName("use_cache")] + public bool UseCache { get; set; } + + [JsonPropertyName("tie_word_embeddings")] + public bool TieWordEmbeddings { get; set; } + + [JsonPropertyName("rope_theta")] + public double RopeTheta { get; set; } + + // [JsonPropertyName("rope_scaling")] + // public double? RopeScaling { get; set; } = null; + + [JsonPropertyName("partial_rotary_factor")] + public double PartialRotaryFactor { get; set; } + + [JsonPropertyName("qk_layernorm")] + public bool QkLayernorm { get; set; } + + [JsonPropertyName("bos_token_id")] + public int BosTokenId { get; set; } + + [JsonPropertyName("eos_token_id")] + public int EosTokenId { get; set; } + + public ScalarType Dtype { get; set; } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2Extension.cs b/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2Extension.cs new file mode 100644 index 0000000000..9c3043a03b --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2Extension.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.ML.GenAI.Phi.Extension; + +public static class Phi2Extension +{ +} diff --git a/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2ForCasualLM.cs b/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2ForCasualLM.cs new file mode 100644 index 0000000000..af65f00ff2 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2ForCasualLM.cs @@ -0,0 +1,64 @@ +using System.CodeDom; +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.ML.GenAI.Core; +using Microsoft.ML.GenAI.Phi.Module; +using TorchSharp; +using TorchSharp.Modules; +using TorchSharp.PyBridge; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Phi; + +public class Phi2ForCasualLM : nn.Module +{ +#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly Phi2Model model; + private readonly GenAILinear lm_head; +#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format + + public Phi2ForCasualLM(Phi2Config config) + : base(nameof(Phi2ForCasualLM)) + { + this.model = new Phi2Model(config); + this.lm_head = new GenAILinear(config.HiddenSize, config.VocabSize, dtype: config.Dtype); + this.RegisterComponents(); + } + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override CasualLMModelOutput forward(CasualLMModelInput input) // use_cache, output_attentions, output_hidden_states +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + var inputIds = input.InputIds; + var attentionMask = input.AttentionMask; + var pastKeyValueLength = input.PastKeyValuesLength; + var positionIds = input.PositionIds; + var inputEmbeddings = input.InputEmbeddings; + var options = (input.OutputAttentions, input.OutputHiddenStates, false); + var output = this.model.forward(inputIds, attentionMask, pastKeyValueLength, positionIds, inputEmbeddings, options); + var hiddenState = output.Item1; + + var lmLogits = this.lm_head.forward(hiddenState); + + return new CasualLMModelOutput(lastHiddenState: hiddenState, logits: lmLogits); + } + + public static Phi2ForCasualLM FromPretrained( + string modelFolder, + string configName = "config.json", + string checkPointName = "model.safetensors.index.json", + ScalarType torchDtype = ScalarType.Float32, + bool useTqdm = false, + string? device = null) + { + var config = Path.Join(modelFolder, configName); + var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); + modelConfig.Dtype = torchDtype; + var wrapper = new Phi2ForCasualLM(modelConfig); + var loadedParameters = new Dictionary(); + wrapper.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: true, loadedParameters: loadedParameters, useTqdm: useTqdm); + wrapper = wrapper.to(device); + wrapper.eval(); + return wrapper; + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs new file mode 100644 index 0000000000..2e5f755d95 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs @@ -0,0 +1,113 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Text.Json.Serialization; +using System.Threading.Tasks; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Phi; +public class Phi3Config +{ + public Phi3Config() + { + this.VocabSize = 32064; + this.HiddenSize = 3072; + this.RmsNormEps = 1e-5f; + this.IntermediateSize = 8192; + this.NumHiddenLayers = 32; + this.NumAttentionHeads = 32; + this.ResidPdrop = 0.0; + this.EmbdPdrop = 0.0; + this.AttentionDropout = 0.0; + this.HiddenAct = "silu"; + this.MaxPositionEmbeddings = 4096; + this.OriginalMaxPositionEmbeddings = 4096; + this.InitializerRange = 0.02; + this.UseCache = true; + this.TieWordEmbeddings = false; + this.RopeTheta = 10000.0; + this.PartialRotaryFactor = 0.5; + this.QkLayernorm = false; + this.BosTokenId = 1; + this.EosTokenId = 32000; + this.DType = ScalarType.BFloat16; + this.AttnImplementation = "eager"; + } + + [JsonPropertyName("vocab_size")] + public int VocabSize { get; set; } + + [JsonPropertyName("hidden_size")] + public int HiddenSize { get; set; } + + [JsonPropertyName("rms_norm_eps")] + public float RmsNormEps { get; set; } + + [JsonPropertyName("intermediate_size")] + public int IntermediateSize { get; set; } + + [JsonPropertyName("num_hidden_layers")] + public int NumHiddenLayers { get; set; } + + [JsonPropertyName("num_attention_heads")] + public int NumAttentionHeads { get; set; } + + [JsonPropertyName("num_key_value_heads")] + public int? NumKeyValueHeads { get; set; } + + [JsonPropertyName("resid_pdrop")] + public double ResidPdrop { get; set; } + + [JsonPropertyName("embd_pdrop")] + public double EmbdPdrop { get; set; } + + [JsonPropertyName("attention_dropout")] + public double AttentionDropout { get; set; } + + [JsonPropertyName("hidden_act")] + public string HiddenAct { get; set; } + + [JsonPropertyName("max_position_embeddings")] + public int MaxPositionEmbeddings { get; set; } + + [JsonPropertyName("original_max_position_embeddings")] + public int OriginalMaxPositionEmbeddings { get; set; } + + [JsonPropertyName("initializer_range")] + public double InitializerRange { get; set; } + + [JsonPropertyName("use_cache")] + public bool UseCache { get; set; } + + [JsonPropertyName("tie_word_embeddings")] + public bool TieWordEmbeddings { get; set; } + + [JsonPropertyName("rope_theta")] + public double RopeTheta { get; set; } + + [JsonPropertyName("rope_scaling")] + public Dictionary? RopeScaling { get; set; } + + [JsonPropertyName("partial_rotary_factor")] + public double PartialRotaryFactor { get; set; } + + [JsonPropertyName("qk_layernorm")] + public bool QkLayernorm { get; set; } + + [JsonPropertyName("bos_token_id")] + public int BosTokenId { get; set; } + + [JsonPropertyName("eos_token_id")] + public int EosTokenId { get; set; } + + [JsonPropertyName("pad_token_id")] + public int? PadTokenId { get; set; } + + [JsonPropertyName("sliding_window")] + public int? SlidingWindow { get; set; } + + public ScalarType DType { get; set; } + + public string AttnImplementation { get; set; } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs new file mode 100644 index 0000000000..9992c92c30 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs @@ -0,0 +1,65 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.ML.GenAI.Core; +using Microsoft.ML.GenAI.Phi.Module; +using TorchSharp; +using TorchSharp.Modules; +using TorchSharp.PyBridge; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Phi; + +public class Phi3ForCasualLM : nn.Module +{ + private readonly Phi3Config _config; + +#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly Phi3Model model; + private readonly GenAILinear lm_head; +#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format + + public Phi3ForCasualLM(Phi3Config config) + : base(nameof(Phi3ForCasualLM)) + { + this._config = config; + this.model = new Phi3Model(config); + this.lm_head = new GenAILinear(config.HiddenSize, config.VocabSize, dtype: config.DType, hasBias: false); + + this.RegisterComponents(); + } + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override CasualLMModelOutput forward(CasualLMModelInput input) +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + var outputs = this.model.forward(input); + var logits = this.lm_head.forward(outputs.LastHiddenState); + logits = logits.to_type(ScalarType.Float32); + outputs.Logits = logits; + + return outputs; + } + + public static Phi3ForCasualLM FromPretrained( + string modelFolder, + string configName = "config.json", + string checkPointName = "model.safetensors.index.json", + ScalarType torchDtype = ScalarType.BFloat16, + string device = "cpu") + { + var config = Path.Join(modelFolder, configName); + var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); + modelConfig.DType = torchDtype; + var phi = new Phi3ForCasualLM(modelConfig); + var loadedParameters = new Dictionary(); + phi.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, loadedParameters: loadedParameters); + phi = phi.to(device); + phi.eval(); + + return phi; + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs new file mode 100644 index 0000000000..8ef4f6fbde --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs @@ -0,0 +1,222 @@ +using Microsoft.ML.Tokenizers; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Text.Json; +using System.Text.RegularExpressions; +using System.Threading.Tasks; + +namespace Microsoft.ML.GenAI.Phi; +public interface ITokenizer +{ + public int BosId { get; } + + public int EosId { get; } + + public string Decode(int[] input); + + public int[] Encode(string input, bool bos, bool eos); +} + +/// +/// Copied from https://github.com/LittleLittleCloud/Torchsharp-llama/blob/main/ITokenizer.cs +/// +public class LLama2Tokenizer : ITokenizer +{ + private readonly SentencePieceBpe _tokenizer; + private readonly bool _addPrecedingSpace; + private const string SystemSymbol = "<|system|>"; + private const string UserSymbol = "<|user|>"; + private const string AssistantSymbol = "<|assistant|>"; + private const string EndSymbol = "<|end|>"; + private const int SystemSymbolId = 32006; + private const int UserSymbolId = 32010; + private const int AssistantSymbolId = 32001; + private const int EndSymbolId = 32007; + private readonly Dictionary _specialTokenMap = new Dictionary + { + { SystemSymbol, SystemSymbolId }, + { UserSymbol, UserSymbolId }, + { AssistantSymbol, AssistantSymbolId }, + { EndSymbol, EndSymbolId } + }; + + public LLama2Tokenizer(string modelPath, bool addPrecedingSpace = true) + { + var modelStream = File.OpenRead(modelPath); + this._addPrecedingSpace = addPrecedingSpace; + this._tokenizer = (SentencePieceBpe)Tokenizer.CreateLlama(modelStream, false, false); + + // use reflection to set the readonly ByteFallback property to false + //var backingField = typeof(SentencePieceBpe).GetField("k__BackingField", BindingFlags.NonPublic | BindingFlags.Instance); + //backingField.SetValue(this.tokenizer, false); + } + //public LLama2Tokenizer(string vocabPath, string mergesPath, bool addPrecedingSpace = true, int padToken = -1, int startToken = 1, int endToken = 2) + //{ + // this.BosId = startToken; + // this.EosId = endToken; + // this.addPrecedingSpace = addPrecedingSpace; + // this.PadId = padToken; + // var bpe = new Bpe(vocabPath, mergesPath); + // this.tokenizer = new Tokenizer(bpe, preTokenizer: new PreTokenizer(), normalizer: new Norm()); + // var decoder = new TokenizeDecoder(this.tokenizer.Model.IdToToken(this.BosId)!, this.tokenizer.Model.IdToToken(this.EosId)!); + // this.tokenizer.Decoder = decoder; + //} + + //public LLama2Tokenizer(Dictionary vocab, List merges, bool addPrecedingSpace = true, int padToken = -1, int startToken = 1, int endToken = 2) + //{ + // this.BosId = startToken; + // this.EosId = endToken; + // this.addPrecedingSpace = addPrecedingSpace; + // this.PadId = padToken; + // // save vocab to vocab-temp.json + // var vocabTempPath = "vocab-temp.json"; + // var json = JsonSerializer.Serialize(vocab); + // File.WriteAllText(vocabTempPath, json); + + // // save merges to merges-temp.txt + // var mergesTempPath = "merges-temp.txt"; + // // filter out merges that contain newline character because it will cause error in BPE + // merges = merges.Where(x => !x.Contains('\r')).ToList(); + // File.WriteAllLines(mergesTempPath, merges); + + // var bpe = new Bpe(vocabTempPath, mergesTempPath); + // this.tokenizer = new Tokenizer(bpe, preTokenizer: new PreTokenizer(), normalizer: new Norm()); + // var decoder = new TokenizeDecoder(this.tokenizer.Model.IdToToken(this.BosId)!, this.tokenizer.Model.IdToToken(this.EosId)!); + // this.tokenizer.Decoder = decoder; + + // // delete temp files + // File.Delete(vocabTempPath); + // File.Delete(mergesTempPath); + //} + + public static LLama2Tokenizer FromPretrained( + string folder, + string modelName = "tokenizer.model") + { + return new LLama2Tokenizer(Path.Combine(folder, modelName)); + } + + //public static LLama2Tokenizer FromPretrained( + // string folder, + // string tokenizerJsonPath = "tokenizer.json", + // string specialTokensMapPath = "special_tokens_map.json" + //) + //{ + // tokenizerJsonPath = Path.Combine(folder, tokenizerJsonPath); + // var json = File.ReadAllText(tokenizerJsonPath); + // var jsonDocument = JsonDocument.Parse(json); + // // vocab: .model.vocab + // var vocabNode = jsonDocument.RootElement.GetProperty("model").GetProperty("vocab"); + + // // to Dictionary + // var vocab = new Dictionary(); + // foreach (var item in vocabNode.EnumerateObject()) + // { + // vocab[item.Name] = item.Value.GetInt32(); + // } + + // // added tokens: .added_tokens + // var addedTokensNode = jsonDocument.RootElement.GetProperty("added_tokens"); + // foreach (var item in addedTokensNode.EnumerateArray()) + // { + // // get id from item.id + // var id = item.GetProperty("id").GetInt32(); + // var content = item.GetProperty("content").GetString()!; + // vocab[content] = id; + // } + + // // merges: .model.merges + // var mergesNode = jsonDocument.RootElement.GetProperty("model").GetProperty("merges"); + // // merges: List + // var merges = new List(); + // foreach (var item in mergesNode.EnumerateArray()) + // { + // merges.Add(item.GetString()!); + // } + + // int startToken = 1, endToken = 2, padToken = -1; + // var specialTokenJsonPath = Path.Combine(folder, specialTokensMapPath); + // if (File.Exists(specialTokenJsonPath)) + // { + // var specialTokenJson = File.ReadAllText(specialTokenJsonPath); + // var specialTokenMapDocument = JsonDocument.Parse(specialTokenJson); + + // // retrieve bos_token, eos_token, pad_token if exists + // if (specialTokenMapDocument.RootElement.TryGetProperty("bos_token", out var bosTokenNode)) + // { + // var bos_token_content = bosTokenNode.GetProperty("content").GetString()!; + // startToken = vocab[bos_token_content]; + // } + + // if (specialTokenMapDocument.RootElement.TryGetProperty("eos_token", out var eosTokenNode)) + // { + // var eos_token_content = eosTokenNode.GetProperty("content").GetString()!; + // endToken = vocab[eos_token_content]; + // } + + // if (specialTokenMapDocument.RootElement.TryGetProperty("pad_token", out var padTokenNode)) + // { + // var pad_token_content = padTokenNode.GetProperty("content").GetString()!; + // padToken = vocab[pad_token_content]; + // } + // } + + // return new LLama2Tokenizer(vocab, merges, padToken: padToken, addPrecedingSpace: false, startToken: startToken, endToken: endToken); + //} + + //public int VocabSize => this.tokenizer..GetVocabSize(); + + public int PadId { get => this._tokenizer.UnknownId; } + + public int BosId { get => this._tokenizer.BeginningOfSentenceId; } + + public int EosId { get => this._tokenizer.EndOfSentenceId; } + + public string Decode(int[] input) + { + var str = this._tokenizer.Decode(input) ?? throw new Exception("Failed to decode"); + if (this._addPrecedingSpace) + { + str = str.TrimStart(); + } + + return str; + } + + public int[] Encode(string input, bool bos, bool eos) + { + // step 1: + // replace all special tokens to + var re = new Regex($"{SystemSymbol.Replace("|", "\\|")}|{UserSymbol.Replace("|", "\\|")}|{AssistantSymbol.Replace("|", "\\|")}|{EndSymbol.Replace("|", "\\|")}"); + var matches = re.Matches(input); + var matchesList = new List(); + var tokens = new List(); + foreach (Match match in matches) + { + // replace the first special tokens with + var specialToken = match.Value; + var index = input.IndexOf(specialToken); + var subString = input.Substring(0, index); + var subTokens = this._tokenizer.EncodeToIds(subString, addBeginningOfSentence: false, addEndOfSentence: false).ToArray(); + tokens.AddRange(subTokens); + tokens.Add(this._specialTokenMap[specialToken]); + input = input.Remove(0, index + specialToken.Length); + } + + tokens.AddRange(this._tokenizer.EncodeToIds(input, addBeginningOfSentence: false, addEndOfSentence: false).ToArray()); + if (bos) + { + tokens.Insert(0, this.BosId); + } + if (eos) + { + tokens.Add(this.EosId); + } + + + return tokens.ToArray(); + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Utils.cs b/src/Microsoft.ML.GenAI.Phi/Utils.cs new file mode 100644 index 0000000000..c4a05bbe40 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Utils.cs @@ -0,0 +1,159 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using TorchSharp.Modules; +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.nn; +namespace Microsoft.ML.GenAI.Phi; + +public static class Utils +{ + public static Tensor ApplyRotaryEmbeddings(Tensor input, Tensor freqsComplex) + { + // Separate the last dimension pairs of two values, representing the real and imaginary parts of the complex number + // Two consecutive values will become a single complex number + // (B, Seq_Len, H, Head_Dim) -> (B, Seq_Len, H, Head_Dim/2) + var inputComplex = input.to_type(ScalarType.Float32).reshape(input.shape[0], input.shape[1], input.shape[2], -1, 2).view_as_complex(); + freqsComplex = freqsComplex.to(input.device); + + // Reshape the freqs_complex tensor to match the shape of the x_complex tensor. So we need to add the batch dimension and the head dimension + // (Seq_Len, Head_Dim/2) --> (1, Seq_Len, 1, Head_Dim/2) + var freqsComplexReshaped = freqsComplex.unsqueeze(0).unsqueeze(2); + + // Multiply each complex number in the x_complex tensor by the corresponding complex number in the freqs_complex tensor + // Which results in the rotation of the complex number as shown in the Figure 1 of the paper + // (B, Seq_Len, H, Head_Dim/2) * (1, Seq_Len, 1, Head_Dim/2) = (B, Seq_Len, H, Head_Dim/2) + var rotatedComplex = inputComplex * freqsComplexReshaped; + // Console.WriteLine(rotated_complex.mean().ToSingle()); + + // Convert the complex number back to the real number + // (B, Seq_Len, H, Head_Dim/2) -> (B, Seq_Len, H, Head_Dim/2, 2) + var rotated = rotatedComplex.view_as_real(); + + // (B, Seq_Len, H, Head_Dim/2, 2) -> (B, Seq_Len, H, Head_Dim) + var rotatedReshaped = rotated.reshape(rotated.shape[0], rotated.shape[1], rotated.shape[2], -1); + + return rotatedReshaped.type_as(input); + } + + public static Tensor PrecomputeThetaPosFrequencies(int headDim, int seqLen, string device, float theta = 10000.0f) + { + // As written in the paragraph 3.2.2 of the paper + // >> In order to generalize our results in 2D to any xi ∈ Rd where **d is even**, [...] + if (headDim % 2 != 0) + { + throw new ArgumentException("Dimension must be divisible by 2", nameof(headDim)); + } + + // Build the theta parameter + // According to the formula theta_i = 10000^(-2(i-1)/dim) for i = [1, 2, ... dim/2] + // Shape: (Head_Dim / 2) + var thetaNumerator = torch.arange(0, headDim, 2).to(torch.float32).to(device); + // Shape: (Head_Dim / 2) + var thetaInput = torch.pow(theta, -1.0f * (thetaNumerator / headDim)).to(device); // (Dim / 2) + // Construct the positions (the "m" parameter) + // Shape: (Seq_Len) + var m = torch.arange(seqLen, device: device); + // Multiply each theta by each position using the outer product. + // Shape: (Seq_Len) outer_product* (Head_Dim / 2) -> (Seq_Len, Head_Dim / 2) + var freqs = torch.outer(m, thetaInput).to(torch.float32).to(device); + + // We can compute complex numbers in the polar form c = R * exp(m * theta), where R = 1 as follows: + // (Seq_Len, Head_Dim / 2) -> (Seq_Len, Head_Dim / 2) + var freqsComplex = torch.polar(torch.ones_like(freqs), freqs); + + return freqsComplex; + } + + // python + // def rotate_half(x): + // """Rotates half the hidden dims of the input.""" + // x1 = x[..., : x.shape[-1] // 2] + // x2 = x[..., x.shape[-1] // 2 :] + // return torch.cat((-x2, x1), dim=-1) + public static Tensor RotateHalf(Tensor x) + { + var x1 = x[.., .., .., ..(int)(x.shape[^1] / 2)]; + var x2 = x[.., .., .., (int)(x.shape[^1] / 2)..]; + // (x1 * x1 * x2).Peek("x1 * x1 * x2"); + return torch.cat([-x2, x1], dim: -1); + } + + public static (Tensor, Tensor) ApplyRotaryPosEmb(Tensor q, Tensor k, Tensor cos, Tensor sin, Tensor? positionIds = null, int unsqueezeDim = 1) + { + // The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + // sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + // that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + // k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + // cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + // the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + + if (positionIds is not null) + { + cos = cos[positionIds!].unsqueeze(unsqueezeDim); + sin = sin[positionIds!].unsqueeze(unsqueezeDim); + } + else + { + cos = cos.unsqueeze(unsqueezeDim); + sin = sin.unsqueeze(unsqueezeDim); + } + var qEmbed = q * cos; + qEmbed += RotateHalf(q) * sin; + + var kEmbed = k * cos; + kEmbed += RotateHalf(k) * sin; + // var kEmbed = (k * cos) + (RotateHalf(k) * sin); + return (qEmbed, kEmbed); + } + + public static Module GetActivation(string actFn) + { + return actFn switch + { + "silu" => nn.SiLU(), + "relu" => nn.ReLU(), + "gelu" => nn.GELU(), + "tanh" => nn.Tanh(), + "swish" => nn.SiLU(), + _ => throw new ArgumentException("Invalid activation function", actFn), + }; + } + + + public static Tensor Phi2RepeatKV(Tensor x, int nRep) + { + var batchSize = x.shape[0]; + var seqLen = x.shape[1]; + var nKVHeads = x.shape[2]; + var headDim = x.shape[3]; + if (nRep == 1) + { + return x; + } + + return x.unsqueeze(3) + .expand(batchSize, seqLen, nKVHeads, nRep, headDim) + .view(batchSize, seqLen, nKVHeads * nRep, headDim); + } + + public static Tensor Phi3RepeatKV(Tensor x, int nRep) + { + var batchSize = x.shape[0]; + var nKVHeads = x.shape[1]; + var seqLen = x.shape[2]; + var headDim = x.shape[3]; + if (nRep == 1) + { + return x; + } + + return x.unsqueeze(3) + .expand(batchSize, nKVHeads, nRep, seqLen, headDim) + .view(batchSize, nKVHeads * nRep, seqLen, headDim); + } + +} diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs index 7a31f5e37f..3c01f6ec1e 100644 --- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs @@ -527,6 +527,24 @@ public static Tokenizer CreatePhi2( bool addEndOfSentence = false) => CreateCodeGen(vocabStream, mergesStream, addPrefixSpace, addBeginOfSentence, addEndOfSentence); + public static CodeGen CreatePhi2( + string folder, + string vocabFile = "vocab.json", + string mergesFile = "merges.txt", + string specialTokensFile = "special_tokens_map.json", + bool addPrefixSpace = false, + bool addBeginOfSentence = false, + bool addEndOfSentence = false) + { + var vocabPath = Path.Combine(folder, vocabFile); + var mergesPath = Path.Combine(folder, mergesFile); + var specialTokenMapPath = Path.Combine(folder, specialTokensFile); + using var vocabStream = File.OpenRead(vocabPath); + using var mergesStream = File.OpenRead(mergesPath); + + return (CodeGen)CreateCodeGen(vocabStream, mergesStream, addPrefixSpace, addBeginOfSentence, addEndOfSentence); + } + internal static IEnumerable<(int Offset, int Length)>? InitializeForEncoding( string? text, ReadOnlySpan textSpan, diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi2Test.LoadSafeTensorShapeTest.approved.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi2Test.LoadSafeTensorShapeTest.approved.txt new file mode 100644 index 0000000000..75e17ad1a6 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi2Test.LoadSafeTensorShapeTest.approved.txt @@ -0,0 +1,453 @@ +0: lm_head.bias shape: [51200] +1: lm_head.weight shape: [51200, 2560] +2: model.embed_tokens.weight shape: [51200, 2560] +3: model.final_layernorm.bias shape: [2560] +4: model.final_layernorm.weight shape: [2560] +5: model.layers.0.input_layernorm.bias shape: [2560] +6: model.layers.0.input_layernorm.weight shape: [2560] +7: model.layers.0.mlp.fc1.bias shape: [10240] +8: model.layers.0.mlp.fc1.weight shape: [10240, 2560] +9: model.layers.0.mlp.fc2.bias shape: [2560] +10: model.layers.0.mlp.fc2.weight shape: [2560, 10240] +11: model.layers.0.self_attn.dense.bias shape: [2560] +12: model.layers.0.self_attn.dense.weight shape: [2560, 2560] +13: model.layers.0.self_attn.k_proj.bias shape: [2560] +14: model.layers.0.self_attn.k_proj.weight shape: [2560, 2560] +15: model.layers.0.self_attn.q_proj.bias shape: [2560] +16: model.layers.0.self_attn.q_proj.weight shape: [2560, 2560] +17: model.layers.0.self_attn.v_proj.bias shape: [2560] +18: model.layers.0.self_attn.v_proj.weight shape: [2560, 2560] +19: model.layers.1.input_layernorm.bias shape: [2560] +20: model.layers.1.input_layernorm.weight shape: [2560] +21: model.layers.1.mlp.fc1.bias shape: [10240] +22: model.layers.1.mlp.fc1.weight shape: [10240, 2560] +23: model.layers.1.mlp.fc2.bias shape: [2560] +24: model.layers.1.mlp.fc2.weight shape: [2560, 10240] +25: model.layers.1.self_attn.dense.bias shape: [2560] +26: model.layers.1.self_attn.dense.weight shape: [2560, 2560] +27: model.layers.1.self_attn.k_proj.bias shape: [2560] +28: model.layers.1.self_attn.k_proj.weight shape: [2560, 2560] +29: model.layers.1.self_attn.q_proj.bias shape: [2560] +30: model.layers.1.self_attn.q_proj.weight shape: [2560, 2560] +31: model.layers.1.self_attn.v_proj.bias shape: [2560] +32: model.layers.1.self_attn.v_proj.weight shape: [2560, 2560] +33: model.layers.10.input_layernorm.bias shape: [2560] +34: model.layers.10.input_layernorm.weight shape: [2560] +35: model.layers.10.mlp.fc1.bias shape: [10240] +36: model.layers.10.mlp.fc1.weight shape: [10240, 2560] +37: model.layers.10.mlp.fc2.bias shape: [2560] +38: model.layers.10.mlp.fc2.weight shape: [2560, 10240] +39: model.layers.10.self_attn.dense.bias shape: [2560] +40: model.layers.10.self_attn.dense.weight shape: [2560, 2560] +41: model.layers.10.self_attn.k_proj.bias shape: [2560] +42: model.layers.10.self_attn.k_proj.weight shape: [2560, 2560] +43: model.layers.10.self_attn.q_proj.bias shape: [2560] +44: model.layers.10.self_attn.q_proj.weight shape: [2560, 2560] +45: model.layers.10.self_attn.v_proj.bias shape: [2560] +46: model.layers.10.self_attn.v_proj.weight shape: [2560, 2560] +47: model.layers.11.input_layernorm.bias shape: [2560] +48: model.layers.11.input_layernorm.weight shape: [2560] +49: model.layers.11.mlp.fc1.bias shape: [10240] +50: model.layers.11.mlp.fc1.weight shape: [10240, 2560] +51: model.layers.11.mlp.fc2.bias shape: [2560] +52: model.layers.11.mlp.fc2.weight shape: [2560, 10240] +53: model.layers.11.self_attn.dense.bias shape: [2560] +54: model.layers.11.self_attn.dense.weight shape: [2560, 2560] +55: model.layers.11.self_attn.k_proj.bias shape: [2560] +56: model.layers.11.self_attn.k_proj.weight shape: [2560, 2560] +57: model.layers.11.self_attn.q_proj.bias shape: [2560] +58: model.layers.11.self_attn.q_proj.weight shape: [2560, 2560] +59: model.layers.11.self_attn.v_proj.bias shape: [2560] +60: model.layers.11.self_attn.v_proj.weight shape: [2560, 2560] +61: model.layers.12.input_layernorm.bias shape: [2560] +62: model.layers.12.input_layernorm.weight shape: [2560] +63: model.layers.12.mlp.fc1.bias shape: [10240] +64: model.layers.12.mlp.fc1.weight shape: [10240, 2560] +65: model.layers.12.mlp.fc2.bias shape: [2560] +66: model.layers.12.mlp.fc2.weight shape: [2560, 10240] +67: model.layers.12.self_attn.dense.bias shape: [2560] +68: model.layers.12.self_attn.dense.weight shape: [2560, 2560] +69: model.layers.12.self_attn.k_proj.bias shape: [2560] +70: model.layers.12.self_attn.k_proj.weight shape: [2560, 2560] +71: model.layers.12.self_attn.q_proj.bias shape: [2560] +72: model.layers.12.self_attn.q_proj.weight shape: [2560, 2560] +73: model.layers.12.self_attn.v_proj.bias shape: [2560] +74: model.layers.12.self_attn.v_proj.weight shape: [2560, 2560] +75: model.layers.13.input_layernorm.bias shape: [2560] +76: model.layers.13.input_layernorm.weight shape: [2560] +77: model.layers.13.mlp.fc1.bias shape: [10240] +78: model.layers.13.mlp.fc1.weight shape: [10240, 2560] +79: model.layers.13.mlp.fc2.bias shape: [2560] +80: model.layers.13.mlp.fc2.weight shape: [2560, 10240] +81: model.layers.13.self_attn.dense.bias shape: [2560] +82: model.layers.13.self_attn.dense.weight shape: [2560, 2560] +83: model.layers.13.self_attn.k_proj.bias shape: [2560] +84: model.layers.13.self_attn.k_proj.weight shape: [2560, 2560] +85: model.layers.13.self_attn.q_proj.bias shape: [2560] +86: model.layers.13.self_attn.q_proj.weight shape: [2560, 2560] +87: model.layers.13.self_attn.v_proj.bias shape: [2560] +88: model.layers.13.self_attn.v_proj.weight shape: [2560, 2560] +89: model.layers.14.input_layernorm.bias shape: [2560] +90: model.layers.14.input_layernorm.weight shape: [2560] +91: model.layers.14.mlp.fc1.bias shape: [10240] +92: model.layers.14.mlp.fc1.weight shape: [10240, 2560] +93: model.layers.14.mlp.fc2.bias shape: [2560] +94: model.layers.14.mlp.fc2.weight shape: [2560, 10240] +95: model.layers.14.self_attn.dense.bias shape: [2560] +96: model.layers.14.self_attn.dense.weight shape: [2560, 2560] +97: model.layers.14.self_attn.k_proj.bias shape: [2560] +98: model.layers.14.self_attn.k_proj.weight shape: [2560, 2560] +99: model.layers.14.self_attn.q_proj.bias shape: [2560] +100: model.layers.14.self_attn.q_proj.weight shape: [2560, 2560] +101: model.layers.14.self_attn.v_proj.bias shape: [2560] +102: model.layers.14.self_attn.v_proj.weight shape: [2560, 2560] +103: model.layers.15.input_layernorm.bias shape: [2560] +104: model.layers.15.input_layernorm.weight shape: [2560] +105: model.layers.15.mlp.fc1.bias shape: [10240] +106: model.layers.15.mlp.fc1.weight shape: [10240, 2560] +107: model.layers.15.mlp.fc2.bias shape: [2560] +108: model.layers.15.mlp.fc2.weight shape: [2560, 10240] +109: model.layers.15.self_attn.dense.bias shape: [2560] +110: model.layers.15.self_attn.dense.weight shape: [2560, 2560] +111: model.layers.15.self_attn.k_proj.bias shape: [2560] +112: model.layers.15.self_attn.k_proj.weight shape: [2560, 2560] +113: model.layers.15.self_attn.q_proj.bias shape: [2560] +114: model.layers.15.self_attn.q_proj.weight shape: [2560, 2560] +115: model.layers.15.self_attn.v_proj.bias shape: [2560] +116: model.layers.15.self_attn.v_proj.weight shape: [2560, 2560] +117: model.layers.16.input_layernorm.bias shape: [2560] +118: model.layers.16.input_layernorm.weight shape: [2560] +119: model.layers.16.mlp.fc1.bias shape: [10240] +120: model.layers.16.mlp.fc1.weight shape: [10240, 2560] +121: model.layers.16.mlp.fc2.bias shape: [2560] +122: model.layers.16.mlp.fc2.weight shape: [2560, 10240] +123: model.layers.16.self_attn.dense.bias shape: [2560] +124: model.layers.16.self_attn.dense.weight shape: [2560, 2560] +125: model.layers.16.self_attn.k_proj.bias shape: [2560] +126: model.layers.16.self_attn.k_proj.weight shape: [2560, 2560] +127: model.layers.16.self_attn.q_proj.bias shape: [2560] +128: model.layers.16.self_attn.q_proj.weight shape: [2560, 2560] +129: model.layers.16.self_attn.v_proj.bias shape: [2560] +130: model.layers.16.self_attn.v_proj.weight shape: [2560, 2560] +131: model.layers.17.input_layernorm.bias shape: [2560] +132: model.layers.17.input_layernorm.weight shape: [2560] +133: model.layers.17.mlp.fc1.bias shape: [10240] +134: model.layers.17.mlp.fc1.weight shape: [10240, 2560] +135: model.layers.17.mlp.fc2.bias shape: [2560] +136: model.layers.17.mlp.fc2.weight shape: [2560, 10240] +137: model.layers.17.self_attn.dense.bias shape: [2560] +138: model.layers.17.self_attn.dense.weight shape: [2560, 2560] +139: model.layers.17.self_attn.k_proj.bias shape: [2560] +140: model.layers.17.self_attn.k_proj.weight shape: [2560, 2560] +141: model.layers.17.self_attn.q_proj.bias shape: [2560] +142: model.layers.17.self_attn.q_proj.weight shape: [2560, 2560] +143: model.layers.17.self_attn.v_proj.bias shape: [2560] +144: model.layers.17.self_attn.v_proj.weight shape: [2560, 2560] +145: model.layers.18.input_layernorm.bias shape: [2560] +146: model.layers.18.input_layernorm.weight shape: [2560] +147: model.layers.18.mlp.fc1.bias shape: [10240] +148: model.layers.18.mlp.fc1.weight shape: [10240, 2560] +149: model.layers.18.mlp.fc2.bias shape: [2560] +150: model.layers.18.mlp.fc2.weight shape: [2560, 10240] +151: model.layers.18.self_attn.dense.bias shape: [2560] +152: model.layers.18.self_attn.dense.weight shape: [2560, 2560] +153: model.layers.18.self_attn.k_proj.bias shape: [2560] +154: model.layers.18.self_attn.k_proj.weight shape: [2560, 2560] +155: model.layers.18.self_attn.q_proj.bias shape: [2560] +156: model.layers.18.self_attn.q_proj.weight shape: [2560, 2560] +157: model.layers.18.self_attn.v_proj.bias shape: [2560] +158: model.layers.18.self_attn.v_proj.weight shape: [2560, 2560] +159: model.layers.19.input_layernorm.bias shape: [2560] +160: model.layers.19.input_layernorm.weight shape: [2560] +161: model.layers.19.mlp.fc1.bias shape: [10240] +162: model.layers.19.mlp.fc1.weight shape: [10240, 2560] +163: model.layers.19.mlp.fc2.bias shape: [2560] +164: model.layers.19.mlp.fc2.weight shape: [2560, 10240] +165: model.layers.19.self_attn.dense.bias shape: [2560] +166: model.layers.19.self_attn.dense.weight shape: [2560, 2560] +167: model.layers.19.self_attn.k_proj.bias shape: [2560] +168: model.layers.19.self_attn.k_proj.weight shape: [2560, 2560] +169: model.layers.19.self_attn.q_proj.bias shape: [2560] +170: model.layers.19.self_attn.q_proj.weight shape: [2560, 2560] +171: model.layers.19.self_attn.v_proj.bias shape: [2560] +172: model.layers.19.self_attn.v_proj.weight shape: [2560, 2560] +173: model.layers.2.input_layernorm.bias shape: [2560] +174: model.layers.2.input_layernorm.weight shape: [2560] +175: model.layers.2.mlp.fc1.bias shape: [10240] +176: model.layers.2.mlp.fc1.weight shape: [10240, 2560] +177: model.layers.2.mlp.fc2.bias shape: [2560] +178: model.layers.2.mlp.fc2.weight shape: [2560, 10240] +179: model.layers.2.self_attn.dense.bias shape: [2560] +180: model.layers.2.self_attn.dense.weight shape: [2560, 2560] +181: model.layers.2.self_attn.k_proj.bias shape: [2560] +182: model.layers.2.self_attn.k_proj.weight shape: [2560, 2560] +183: model.layers.2.self_attn.q_proj.bias shape: [2560] +184: model.layers.2.self_attn.q_proj.weight shape: [2560, 2560] +185: model.layers.2.self_attn.v_proj.bias shape: [2560] +186: model.layers.2.self_attn.v_proj.weight shape: [2560, 2560] +187: model.layers.20.input_layernorm.bias shape: [2560] +188: model.layers.20.input_layernorm.weight shape: [2560] +189: model.layers.20.mlp.fc1.bias shape: [10240] +190: model.layers.20.mlp.fc1.weight shape: [10240, 2560] +191: model.layers.20.mlp.fc2.bias shape: [2560] +192: model.layers.20.mlp.fc2.weight shape: [2560, 10240] +193: model.layers.20.self_attn.dense.bias shape: [2560] +194: model.layers.20.self_attn.dense.weight shape: [2560, 2560] +195: model.layers.20.self_attn.k_proj.bias shape: [2560] +196: model.layers.20.self_attn.k_proj.weight shape: [2560, 2560] +197: model.layers.20.self_attn.q_proj.bias shape: [2560] +198: model.layers.20.self_attn.q_proj.weight shape: [2560, 2560] +199: model.layers.20.self_attn.v_proj.bias shape: [2560] +200: model.layers.20.self_attn.v_proj.weight shape: [2560, 2560] +201: model.layers.21.input_layernorm.bias shape: [2560] +202: model.layers.21.input_layernorm.weight shape: [2560] +203: model.layers.21.mlp.fc1.bias shape: [10240] +204: model.layers.21.mlp.fc1.weight shape: [10240, 2560] +205: model.layers.21.mlp.fc2.bias shape: [2560] +206: model.layers.21.mlp.fc2.weight shape: [2560, 10240] +207: model.layers.21.self_attn.dense.bias shape: [2560] +208: model.layers.21.self_attn.dense.weight shape: [2560, 2560] +209: model.layers.21.self_attn.k_proj.bias shape: [2560] +210: model.layers.21.self_attn.k_proj.weight shape: [2560, 2560] +211: model.layers.21.self_attn.q_proj.bias shape: [2560] +212: model.layers.21.self_attn.q_proj.weight shape: [2560, 2560] +213: model.layers.21.self_attn.v_proj.bias shape: [2560] +214: model.layers.21.self_attn.v_proj.weight shape: [2560, 2560] +215: model.layers.22.input_layernorm.bias shape: [2560] +216: model.layers.22.input_layernorm.weight shape: [2560] +217: model.layers.22.mlp.fc1.bias shape: [10240] +218: model.layers.22.mlp.fc1.weight shape: [10240, 2560] +219: model.layers.22.mlp.fc2.bias shape: [2560] +220: model.layers.22.mlp.fc2.weight shape: [2560, 10240] +221: model.layers.22.self_attn.dense.bias shape: [2560] +222: model.layers.22.self_attn.dense.weight shape: [2560, 2560] +223: model.layers.22.self_attn.k_proj.bias shape: [2560] +224: model.layers.22.self_attn.k_proj.weight shape: [2560, 2560] +225: model.layers.22.self_attn.q_proj.bias shape: [2560] +226: model.layers.22.self_attn.q_proj.weight shape: [2560, 2560] +227: model.layers.22.self_attn.v_proj.bias shape: [2560] +228: model.layers.22.self_attn.v_proj.weight shape: [2560, 2560] +229: model.layers.23.input_layernorm.bias shape: [2560] +230: model.layers.23.input_layernorm.weight shape: [2560] +231: model.layers.23.mlp.fc1.bias shape: [10240] +232: model.layers.23.mlp.fc1.weight shape: [10240, 2560] +233: model.layers.23.mlp.fc2.bias shape: [2560] +234: model.layers.23.mlp.fc2.weight shape: [2560, 10240] +235: model.layers.23.self_attn.dense.bias shape: [2560] +236: model.layers.23.self_attn.dense.weight shape: [2560, 2560] +237: model.layers.23.self_attn.k_proj.bias shape: [2560] +238: model.layers.23.self_attn.k_proj.weight shape: [2560, 2560] +239: model.layers.23.self_attn.q_proj.bias shape: [2560] +240: model.layers.23.self_attn.q_proj.weight shape: [2560, 2560] +241: model.layers.23.self_attn.v_proj.bias shape: [2560] +242: model.layers.23.self_attn.v_proj.weight shape: [2560, 2560] +243: model.layers.24.input_layernorm.bias shape: [2560] +244: model.layers.24.input_layernorm.weight shape: [2560] +245: model.layers.24.mlp.fc1.bias shape: [10240] +246: model.layers.24.mlp.fc1.weight shape: [10240, 2560] +247: model.layers.24.mlp.fc2.bias shape: [2560] +248: model.layers.24.mlp.fc2.weight shape: [2560, 10240] +249: model.layers.24.self_attn.dense.bias shape: [2560] +250: model.layers.24.self_attn.dense.weight shape: [2560, 2560] +251: model.layers.24.self_attn.k_proj.bias shape: [2560] +252: model.layers.24.self_attn.k_proj.weight shape: [2560, 2560] +253: model.layers.24.self_attn.q_proj.bias shape: [2560] +254: model.layers.24.self_attn.q_proj.weight shape: [2560, 2560] +255: model.layers.24.self_attn.v_proj.bias shape: [2560] +256: model.layers.24.self_attn.v_proj.weight shape: [2560, 2560] +257: model.layers.25.input_layernorm.bias shape: [2560] +258: model.layers.25.input_layernorm.weight shape: [2560] +259: model.layers.25.mlp.fc1.bias shape: [10240] +260: model.layers.25.mlp.fc1.weight shape: [10240, 2560] +261: model.layers.25.mlp.fc2.bias shape: [2560] +262: model.layers.25.mlp.fc2.weight shape: [2560, 10240] +263: model.layers.25.self_attn.dense.bias shape: [2560] +264: model.layers.25.self_attn.dense.weight shape: [2560, 2560] +265: model.layers.25.self_attn.k_proj.bias shape: [2560] +266: model.layers.25.self_attn.k_proj.weight shape: [2560, 2560] +267: model.layers.25.self_attn.q_proj.bias shape: [2560] +268: model.layers.25.self_attn.q_proj.weight shape: [2560, 2560] +269: model.layers.25.self_attn.v_proj.bias shape: [2560] +270: model.layers.25.self_attn.v_proj.weight shape: [2560, 2560] +271: model.layers.26.input_layernorm.bias shape: [2560] +272: model.layers.26.input_layernorm.weight shape: [2560] +273: model.layers.26.mlp.fc1.bias shape: [10240] +274: model.layers.26.mlp.fc1.weight shape: [10240, 2560] +275: model.layers.26.mlp.fc2.bias shape: [2560] +276: model.layers.26.mlp.fc2.weight shape: [2560, 10240] +277: model.layers.26.self_attn.dense.bias shape: [2560] +278: model.layers.26.self_attn.dense.weight shape: [2560, 2560] +279: model.layers.26.self_attn.k_proj.bias shape: [2560] +280: model.layers.26.self_attn.k_proj.weight shape: [2560, 2560] +281: model.layers.26.self_attn.q_proj.bias shape: [2560] +282: model.layers.26.self_attn.q_proj.weight shape: [2560, 2560] +283: model.layers.26.self_attn.v_proj.bias shape: [2560] +284: model.layers.26.self_attn.v_proj.weight shape: [2560, 2560] +285: model.layers.27.input_layernorm.bias shape: [2560] +286: model.layers.27.input_layernorm.weight shape: [2560] +287: model.layers.27.mlp.fc1.bias shape: [10240] +288: model.layers.27.mlp.fc1.weight shape: [10240, 2560] +289: model.layers.27.mlp.fc2.bias shape: [2560] +290: model.layers.27.mlp.fc2.weight shape: [2560, 10240] +291: model.layers.27.self_attn.dense.bias shape: [2560] +292: model.layers.27.self_attn.dense.weight shape: [2560, 2560] +293: model.layers.27.self_attn.k_proj.bias shape: [2560] +294: model.layers.27.self_attn.k_proj.weight shape: [2560, 2560] +295: model.layers.27.self_attn.q_proj.bias shape: [2560] +296: model.layers.27.self_attn.q_proj.weight shape: [2560, 2560] +297: model.layers.27.self_attn.v_proj.bias shape: [2560] +298: model.layers.27.self_attn.v_proj.weight shape: [2560, 2560] +299: model.layers.28.input_layernorm.bias shape: [2560] +300: model.layers.28.input_layernorm.weight shape: [2560] +301: model.layers.28.mlp.fc1.bias shape: [10240] +302: model.layers.28.mlp.fc1.weight shape: [10240, 2560] +303: model.layers.28.mlp.fc2.bias shape: [2560] +304: model.layers.28.mlp.fc2.weight shape: [2560, 10240] +305: model.layers.28.self_attn.dense.bias shape: [2560] +306: model.layers.28.self_attn.dense.weight shape: [2560, 2560] +307: model.layers.28.self_attn.k_proj.bias shape: [2560] +308: model.layers.28.self_attn.k_proj.weight shape: [2560, 2560] +309: model.layers.28.self_attn.q_proj.bias shape: [2560] +310: model.layers.28.self_attn.q_proj.weight shape: [2560, 2560] +311: model.layers.28.self_attn.v_proj.bias shape: [2560] +312: model.layers.28.self_attn.v_proj.weight shape: [2560, 2560] +313: model.layers.29.input_layernorm.bias shape: [2560] +314: model.layers.29.input_layernorm.weight shape: [2560] +315: model.layers.29.mlp.fc1.bias shape: [10240] +316: model.layers.29.mlp.fc1.weight shape: [10240, 2560] +317: model.layers.29.mlp.fc2.bias shape: [2560] +318: model.layers.29.mlp.fc2.weight shape: [2560, 10240] +319: model.layers.29.self_attn.dense.bias shape: [2560] +320: model.layers.29.self_attn.dense.weight shape: [2560, 2560] +321: model.layers.29.self_attn.k_proj.bias shape: [2560] +322: model.layers.29.self_attn.k_proj.weight shape: [2560, 2560] +323: model.layers.29.self_attn.q_proj.bias shape: [2560] +324: model.layers.29.self_attn.q_proj.weight shape: [2560, 2560] +325: model.layers.29.self_attn.v_proj.bias shape: [2560] +326: model.layers.29.self_attn.v_proj.weight shape: [2560, 2560] +327: model.layers.3.input_layernorm.bias shape: [2560] +328: model.layers.3.input_layernorm.weight shape: [2560] +329: model.layers.3.mlp.fc1.bias shape: [10240] +330: model.layers.3.mlp.fc1.weight shape: [10240, 2560] +331: model.layers.3.mlp.fc2.bias shape: [2560] +332: model.layers.3.mlp.fc2.weight shape: [2560, 10240] +333: model.layers.3.self_attn.dense.bias shape: [2560] +334: model.layers.3.self_attn.dense.weight shape: [2560, 2560] +335: model.layers.3.self_attn.k_proj.bias shape: [2560] +336: model.layers.3.self_attn.k_proj.weight shape: [2560, 2560] +337: model.layers.3.self_attn.q_proj.bias shape: [2560] +338: model.layers.3.self_attn.q_proj.weight shape: [2560, 2560] +339: model.layers.3.self_attn.v_proj.bias shape: [2560] +340: model.layers.3.self_attn.v_proj.weight shape: [2560, 2560] +341: model.layers.30.input_layernorm.bias shape: [2560] +342: model.layers.30.input_layernorm.weight shape: [2560] +343: model.layers.30.mlp.fc1.bias shape: [10240] +344: model.layers.30.mlp.fc1.weight shape: [10240, 2560] +345: model.layers.30.mlp.fc2.bias shape: [2560] +346: model.layers.30.mlp.fc2.weight shape: [2560, 10240] +347: model.layers.30.self_attn.dense.bias shape: [2560] +348: model.layers.30.self_attn.dense.weight shape: [2560, 2560] +349: model.layers.30.self_attn.k_proj.bias shape: [2560] +350: model.layers.30.self_attn.k_proj.weight shape: [2560, 2560] +351: model.layers.30.self_attn.q_proj.bias shape: [2560] +352: model.layers.30.self_attn.q_proj.weight shape: [2560, 2560] +353: model.layers.30.self_attn.v_proj.bias shape: [2560] +354: model.layers.30.self_attn.v_proj.weight shape: [2560, 2560] +355: model.layers.31.input_layernorm.bias shape: [2560] +356: model.layers.31.input_layernorm.weight shape: [2560] +357: model.layers.31.mlp.fc1.bias shape: [10240] +358: model.layers.31.mlp.fc1.weight shape: [10240, 2560] +359: model.layers.31.mlp.fc2.bias shape: [2560] +360: model.layers.31.mlp.fc2.weight shape: [2560, 10240] +361: model.layers.31.self_attn.dense.bias shape: [2560] +362: model.layers.31.self_attn.dense.weight shape: [2560, 2560] +363: model.layers.31.self_attn.k_proj.bias shape: [2560] +364: model.layers.31.self_attn.k_proj.weight shape: [2560, 2560] +365: model.layers.31.self_attn.q_proj.bias shape: [2560] +366: model.layers.31.self_attn.q_proj.weight shape: [2560, 2560] +367: model.layers.31.self_attn.v_proj.bias shape: [2560] +368: model.layers.31.self_attn.v_proj.weight shape: [2560, 2560] +369: model.layers.4.input_layernorm.bias shape: [2560] +370: model.layers.4.input_layernorm.weight shape: [2560] +371: model.layers.4.mlp.fc1.bias shape: [10240] +372: model.layers.4.mlp.fc1.weight shape: [10240, 2560] +373: model.layers.4.mlp.fc2.bias shape: [2560] +374: model.layers.4.mlp.fc2.weight shape: [2560, 10240] +375: model.layers.4.self_attn.dense.bias shape: [2560] +376: model.layers.4.self_attn.dense.weight shape: [2560, 2560] +377: model.layers.4.self_attn.k_proj.bias shape: [2560] +378: model.layers.4.self_attn.k_proj.weight shape: [2560, 2560] +379: model.layers.4.self_attn.q_proj.bias shape: [2560] +380: model.layers.4.self_attn.q_proj.weight shape: [2560, 2560] +381: model.layers.4.self_attn.v_proj.bias shape: [2560] +382: model.layers.4.self_attn.v_proj.weight shape: [2560, 2560] +383: model.layers.5.input_layernorm.bias shape: [2560] +384: model.layers.5.input_layernorm.weight shape: [2560] +385: model.layers.5.mlp.fc1.bias shape: [10240] +386: model.layers.5.mlp.fc1.weight shape: [10240, 2560] +387: model.layers.5.mlp.fc2.bias shape: [2560] +388: model.layers.5.mlp.fc2.weight shape: [2560, 10240] +389: model.layers.5.self_attn.dense.bias shape: [2560] +390: model.layers.5.self_attn.dense.weight shape: [2560, 2560] +391: model.layers.5.self_attn.k_proj.bias shape: [2560] +392: model.layers.5.self_attn.k_proj.weight shape: [2560, 2560] +393: model.layers.5.self_attn.q_proj.bias shape: [2560] +394: model.layers.5.self_attn.q_proj.weight shape: [2560, 2560] +395: model.layers.5.self_attn.v_proj.bias shape: [2560] +396: model.layers.5.self_attn.v_proj.weight shape: [2560, 2560] +397: model.layers.6.input_layernorm.bias shape: [2560] +398: model.layers.6.input_layernorm.weight shape: [2560] +399: model.layers.6.mlp.fc1.bias shape: [10240] +400: model.layers.6.mlp.fc1.weight shape: [10240, 2560] +401: model.layers.6.mlp.fc2.bias shape: [2560] +402: model.layers.6.mlp.fc2.weight shape: [2560, 10240] +403: model.layers.6.self_attn.dense.bias shape: [2560] +404: model.layers.6.self_attn.dense.weight shape: [2560, 2560] +405: model.layers.6.self_attn.k_proj.bias shape: [2560] +406: model.layers.6.self_attn.k_proj.weight shape: [2560, 2560] +407: model.layers.6.self_attn.q_proj.bias shape: [2560] +408: model.layers.6.self_attn.q_proj.weight shape: [2560, 2560] +409: model.layers.6.self_attn.v_proj.bias shape: [2560] +410: model.layers.6.self_attn.v_proj.weight shape: [2560, 2560] +411: model.layers.7.input_layernorm.bias shape: [2560] +412: model.layers.7.input_layernorm.weight shape: [2560] +413: model.layers.7.mlp.fc1.bias shape: [10240] +414: model.layers.7.mlp.fc1.weight shape: [10240, 2560] +415: model.layers.7.mlp.fc2.bias shape: [2560] +416: model.layers.7.mlp.fc2.weight shape: [2560, 10240] +417: model.layers.7.self_attn.dense.bias shape: [2560] +418: model.layers.7.self_attn.dense.weight shape: [2560, 2560] +419: model.layers.7.self_attn.k_proj.bias shape: [2560] +420: model.layers.7.self_attn.k_proj.weight shape: [2560, 2560] +421: model.layers.7.self_attn.q_proj.bias shape: [2560] +422: model.layers.7.self_attn.q_proj.weight shape: [2560, 2560] +423: model.layers.7.self_attn.v_proj.bias shape: [2560] +424: model.layers.7.self_attn.v_proj.weight shape: [2560, 2560] +425: model.layers.8.input_layernorm.bias shape: [2560] +426: model.layers.8.input_layernorm.weight shape: [2560] +427: model.layers.8.mlp.fc1.bias shape: [10240] +428: model.layers.8.mlp.fc1.weight shape: [10240, 2560] +429: model.layers.8.mlp.fc2.bias shape: [2560] +430: model.layers.8.mlp.fc2.weight shape: [2560, 10240] +431: model.layers.8.self_attn.dense.bias shape: [2560] +432: model.layers.8.self_attn.dense.weight shape: [2560, 2560] +433: model.layers.8.self_attn.k_proj.bias shape: [2560] +434: model.layers.8.self_attn.k_proj.weight shape: [2560, 2560] +435: model.layers.8.self_attn.q_proj.bias shape: [2560] +436: model.layers.8.self_attn.q_proj.weight shape: [2560, 2560] +437: model.layers.8.self_attn.v_proj.bias shape: [2560] +438: model.layers.8.self_attn.v_proj.weight shape: [2560, 2560] +439: model.layers.9.input_layernorm.bias shape: [2560] +440: model.layers.9.input_layernorm.weight shape: [2560] +441: model.layers.9.mlp.fc1.bias shape: [10240] +442: model.layers.9.mlp.fc1.weight shape: [10240, 2560] +443: model.layers.9.mlp.fc2.bias shape: [2560] +444: model.layers.9.mlp.fc2.weight shape: [2560, 10240] +445: model.layers.9.self_attn.dense.bias shape: [2560] +446: model.layers.9.self_attn.dense.weight shape: [2560, 2560] +447: model.layers.9.self_attn.k_proj.bias shape: [2560] +448: model.layers.9.self_attn.k_proj.weight shape: [2560, 2560] +449: model.layers.9.self_attn.q_proj.bias shape: [2560] +450: model.layers.9.self_attn.q_proj.weight shape: [2560, 2560] +451: model.layers.9.self_attn.v_proj.bias shape: [2560] +452: model.layers.9.self_attn.v_proj.weight shape: [2560, 2560] diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi2Test.TokenizerTest.approved.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi2Test.TokenizerTest.approved.txt new file mode 100644 index 0000000000..7338548917 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi2Test.TokenizerTest.approved.txt @@ -0,0 +1,3 @@ +50256, 6090, 345, 2148, 2842, 284, 4483, 17790, 286, 35484, 290, 10441, 69, 50187, 30 +50256, 19457, 0, 3423, 389, 617, 2842, 284, 4483, 35484, 290, 10441, 69, 50187, 1978, 25, 352, 13, 40058, 290, 10441, 34711, 7209, 494, 25, 41198, 35484, 290, 10441, 69, 50187, 1978, 351, 617, 7545, 290, 12498, 13, 362, 13, 40058, 290, 10441, 34711, 20698, 25, 15561, 26790, 35484, 290, 10441, 69, 50187, 1978, 351, 617, 18873, 13135, 290, 12498, 13 +50256, 2061, 546, 18120, 281, 362, 87, 1343, 513, 796, 767, 16022, 30 diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Medium128KShapeTest.approved.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Medium128KShapeTest.approved.txt new file mode 100644 index 0000000000..277f686aa7 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Medium128KShapeTest.approved.txt @@ -0,0 +1,243 @@ +0: lm_head.weight shape: [32064, 5120] +1: model.embed_tokens.weight shape: [32064, 5120] +2: model.layers.0.input_layernorm.weight shape: [5120] +3: model.layers.0.mlp.down_proj.weight shape: [5120, 17920] +4: model.layers.0.mlp.gate_up_proj.weight shape: [35840, 5120] +5: model.layers.0.post_attention_layernorm.weight shape: [5120] +6: model.layers.0.self_attn.o_proj.weight shape: [5120, 5120] +7: model.layers.0.self_attn.qkv_proj.weight shape: [7680, 5120] +8: model.layers.1.input_layernorm.weight shape: [5120] +9: model.layers.1.mlp.down_proj.weight shape: [5120, 17920] +10: model.layers.1.mlp.gate_up_proj.weight shape: [35840, 5120] +11: model.layers.1.post_attention_layernorm.weight shape: [5120] +12: model.layers.1.self_attn.o_proj.weight shape: [5120, 5120] +13: model.layers.1.self_attn.qkv_proj.weight shape: [7680, 5120] +14: model.layers.10.input_layernorm.weight shape: [5120] +15: model.layers.10.mlp.down_proj.weight shape: [5120, 17920] +16: model.layers.10.mlp.gate_up_proj.weight shape: [35840, 5120] +17: model.layers.10.post_attention_layernorm.weight shape: [5120] +18: model.layers.10.self_attn.o_proj.weight shape: [5120, 5120] +19: model.layers.10.self_attn.qkv_proj.weight shape: [7680, 5120] +20: model.layers.11.input_layernorm.weight shape: [5120] +21: model.layers.11.mlp.down_proj.weight shape: [5120, 17920] +22: model.layers.11.mlp.gate_up_proj.weight shape: [35840, 5120] +23: model.layers.11.post_attention_layernorm.weight shape: [5120] +24: model.layers.11.self_attn.o_proj.weight shape: [5120, 5120] +25: model.layers.11.self_attn.qkv_proj.weight shape: [7680, 5120] +26: model.layers.12.input_layernorm.weight shape: [5120] +27: model.layers.12.mlp.down_proj.weight shape: [5120, 17920] +28: model.layers.12.mlp.gate_up_proj.weight shape: [35840, 5120] +29: model.layers.12.post_attention_layernorm.weight shape: [5120] +30: model.layers.12.self_attn.o_proj.weight shape: [5120, 5120] +31: model.layers.12.self_attn.qkv_proj.weight shape: [7680, 5120] +32: model.layers.13.input_layernorm.weight shape: [5120] +33: model.layers.13.mlp.down_proj.weight shape: [5120, 17920] +34: model.layers.13.mlp.gate_up_proj.weight shape: [35840, 5120] +35: model.layers.13.post_attention_layernorm.weight shape: [5120] +36: model.layers.13.self_attn.o_proj.weight shape: [5120, 5120] +37: model.layers.13.self_attn.qkv_proj.weight shape: [7680, 5120] +38: model.layers.14.input_layernorm.weight shape: [5120] +39: model.layers.14.mlp.down_proj.weight shape: [5120, 17920] +40: model.layers.14.mlp.gate_up_proj.weight shape: [35840, 5120] +41: model.layers.14.post_attention_layernorm.weight shape: [5120] +42: model.layers.14.self_attn.o_proj.weight shape: [5120, 5120] +43: model.layers.14.self_attn.qkv_proj.weight shape: [7680, 5120] +44: model.layers.15.input_layernorm.weight shape: [5120] +45: model.layers.15.mlp.down_proj.weight shape: [5120, 17920] +46: model.layers.15.mlp.gate_up_proj.weight shape: [35840, 5120] +47: model.layers.15.post_attention_layernorm.weight shape: [5120] +48: model.layers.15.self_attn.o_proj.weight shape: [5120, 5120] +49: model.layers.15.self_attn.qkv_proj.weight shape: [7680, 5120] +50: model.layers.16.input_layernorm.weight shape: [5120] +51: model.layers.16.mlp.down_proj.weight shape: [5120, 17920] +52: model.layers.16.mlp.gate_up_proj.weight shape: [35840, 5120] +53: model.layers.16.post_attention_layernorm.weight shape: [5120] +54: model.layers.16.self_attn.o_proj.weight shape: [5120, 5120] +55: model.layers.16.self_attn.qkv_proj.weight shape: [7680, 5120] +56: model.layers.17.input_layernorm.weight shape: [5120] +57: model.layers.17.mlp.down_proj.weight shape: [5120, 17920] +58: model.layers.17.mlp.gate_up_proj.weight shape: [35840, 5120] +59: model.layers.17.post_attention_layernorm.weight shape: [5120] +60: model.layers.17.self_attn.o_proj.weight shape: [5120, 5120] +61: model.layers.17.self_attn.qkv_proj.weight shape: [7680, 5120] +62: model.layers.18.input_layernorm.weight shape: [5120] +63: model.layers.18.mlp.down_proj.weight shape: [5120, 17920] +64: model.layers.18.mlp.gate_up_proj.weight shape: [35840, 5120] +65: model.layers.18.post_attention_layernorm.weight shape: [5120] +66: model.layers.18.self_attn.o_proj.weight shape: [5120, 5120] +67: model.layers.18.self_attn.qkv_proj.weight shape: [7680, 5120] +68: model.layers.19.input_layernorm.weight shape: [5120] +69: model.layers.19.mlp.down_proj.weight shape: [5120, 17920] +70: model.layers.19.mlp.gate_up_proj.weight shape: [35840, 5120] +71: model.layers.19.post_attention_layernorm.weight shape: [5120] +72: model.layers.19.self_attn.o_proj.weight shape: [5120, 5120] +73: model.layers.19.self_attn.qkv_proj.weight shape: [7680, 5120] +74: model.layers.2.input_layernorm.weight shape: [5120] +75: model.layers.2.mlp.down_proj.weight shape: [5120, 17920] +76: model.layers.2.mlp.gate_up_proj.weight shape: [35840, 5120] +77: model.layers.2.post_attention_layernorm.weight shape: [5120] +78: model.layers.2.self_attn.o_proj.weight shape: [5120, 5120] +79: model.layers.2.self_attn.qkv_proj.weight shape: [7680, 5120] +80: model.layers.20.input_layernorm.weight shape: [5120] +81: model.layers.20.mlp.down_proj.weight shape: [5120, 17920] +82: model.layers.20.mlp.gate_up_proj.weight shape: [35840, 5120] +83: model.layers.20.post_attention_layernorm.weight shape: [5120] +84: model.layers.20.self_attn.o_proj.weight shape: [5120, 5120] +85: model.layers.20.self_attn.qkv_proj.weight shape: [7680, 5120] +86: model.layers.21.input_layernorm.weight shape: [5120] +87: model.layers.21.mlp.down_proj.weight shape: [5120, 17920] +88: model.layers.21.mlp.gate_up_proj.weight shape: [35840, 5120] +89: model.layers.21.post_attention_layernorm.weight shape: [5120] +90: model.layers.21.self_attn.o_proj.weight shape: [5120, 5120] +91: model.layers.21.self_attn.qkv_proj.weight shape: [7680, 5120] +92: model.layers.22.input_layernorm.weight shape: [5120] +93: model.layers.22.mlp.down_proj.weight shape: [5120, 17920] +94: model.layers.22.mlp.gate_up_proj.weight shape: [35840, 5120] +95: model.layers.22.post_attention_layernorm.weight shape: [5120] +96: model.layers.22.self_attn.o_proj.weight shape: [5120, 5120] +97: model.layers.22.self_attn.qkv_proj.weight shape: [7680, 5120] +98: model.layers.23.input_layernorm.weight shape: [5120] +99: model.layers.23.mlp.down_proj.weight shape: [5120, 17920] +100: model.layers.23.mlp.gate_up_proj.weight shape: [35840, 5120] +101: model.layers.23.post_attention_layernorm.weight shape: [5120] +102: model.layers.23.self_attn.o_proj.weight shape: [5120, 5120] +103: model.layers.23.self_attn.qkv_proj.weight shape: [7680, 5120] +104: model.layers.24.input_layernorm.weight shape: [5120] +105: model.layers.24.mlp.down_proj.weight shape: [5120, 17920] +106: model.layers.24.mlp.gate_up_proj.weight shape: [35840, 5120] +107: model.layers.24.post_attention_layernorm.weight shape: [5120] +108: model.layers.24.self_attn.o_proj.weight shape: [5120, 5120] +109: model.layers.24.self_attn.qkv_proj.weight shape: [7680, 5120] +110: model.layers.25.input_layernorm.weight shape: [5120] +111: model.layers.25.mlp.down_proj.weight shape: [5120, 17920] +112: model.layers.25.mlp.gate_up_proj.weight shape: [35840, 5120] +113: model.layers.25.post_attention_layernorm.weight shape: [5120] +114: model.layers.25.self_attn.o_proj.weight shape: [5120, 5120] +115: model.layers.25.self_attn.qkv_proj.weight shape: [7680, 5120] +116: model.layers.26.input_layernorm.weight shape: [5120] +117: model.layers.26.mlp.down_proj.weight shape: [5120, 17920] +118: model.layers.26.mlp.gate_up_proj.weight shape: [35840, 5120] +119: model.layers.26.post_attention_layernorm.weight shape: [5120] +120: model.layers.26.self_attn.o_proj.weight shape: [5120, 5120] +121: model.layers.26.self_attn.qkv_proj.weight shape: [7680, 5120] +122: model.layers.27.input_layernorm.weight shape: [5120] +123: model.layers.27.mlp.down_proj.weight shape: [5120, 17920] +124: model.layers.27.mlp.gate_up_proj.weight shape: [35840, 5120] +125: model.layers.27.post_attention_layernorm.weight shape: [5120] +126: model.layers.27.self_attn.o_proj.weight shape: [5120, 5120] +127: model.layers.27.self_attn.qkv_proj.weight shape: [7680, 5120] +128: model.layers.28.input_layernorm.weight shape: [5120] +129: model.layers.28.mlp.down_proj.weight shape: [5120, 17920] +130: model.layers.28.mlp.gate_up_proj.weight shape: [35840, 5120] +131: model.layers.28.post_attention_layernorm.weight shape: [5120] +132: model.layers.28.self_attn.o_proj.weight shape: [5120, 5120] +133: model.layers.28.self_attn.qkv_proj.weight shape: [7680, 5120] +134: model.layers.29.input_layernorm.weight shape: [5120] +135: model.layers.29.mlp.down_proj.weight shape: [5120, 17920] +136: model.layers.29.mlp.gate_up_proj.weight shape: [35840, 5120] +137: model.layers.29.post_attention_layernorm.weight shape: [5120] +138: model.layers.29.self_attn.o_proj.weight shape: [5120, 5120] +139: model.layers.29.self_attn.qkv_proj.weight shape: [7680, 5120] +140: model.layers.3.input_layernorm.weight shape: [5120] +141: model.layers.3.mlp.down_proj.weight shape: [5120, 17920] +142: model.layers.3.mlp.gate_up_proj.weight shape: [35840, 5120] +143: model.layers.3.post_attention_layernorm.weight shape: [5120] +144: model.layers.3.self_attn.o_proj.weight shape: [5120, 5120] +145: model.layers.3.self_attn.qkv_proj.weight shape: [7680, 5120] +146: model.layers.30.input_layernorm.weight shape: [5120] +147: model.layers.30.mlp.down_proj.weight shape: [5120, 17920] +148: model.layers.30.mlp.gate_up_proj.weight shape: [35840, 5120] +149: model.layers.30.post_attention_layernorm.weight shape: [5120] +150: model.layers.30.self_attn.o_proj.weight shape: [5120, 5120] +151: model.layers.30.self_attn.qkv_proj.weight shape: [7680, 5120] +152: model.layers.31.input_layernorm.weight shape: [5120] +153: model.layers.31.mlp.down_proj.weight shape: [5120, 17920] +154: model.layers.31.mlp.gate_up_proj.weight shape: [35840, 5120] +155: model.layers.31.post_attention_layernorm.weight shape: [5120] +156: model.layers.31.self_attn.o_proj.weight shape: [5120, 5120] +157: model.layers.31.self_attn.qkv_proj.weight shape: [7680, 5120] +158: model.layers.32.input_layernorm.weight shape: [5120] +159: model.layers.32.mlp.down_proj.weight shape: [5120, 17920] +160: model.layers.32.mlp.gate_up_proj.weight shape: [35840, 5120] +161: model.layers.32.post_attention_layernorm.weight shape: [5120] +162: model.layers.32.self_attn.o_proj.weight shape: [5120, 5120] +163: model.layers.32.self_attn.qkv_proj.weight shape: [7680, 5120] +164: model.layers.33.input_layernorm.weight shape: [5120] +165: model.layers.33.mlp.down_proj.weight shape: [5120, 17920] +166: model.layers.33.mlp.gate_up_proj.weight shape: [35840, 5120] +167: model.layers.33.post_attention_layernorm.weight shape: [5120] +168: model.layers.33.self_attn.o_proj.weight shape: [5120, 5120] +169: model.layers.33.self_attn.qkv_proj.weight shape: [7680, 5120] +170: model.layers.34.input_layernorm.weight shape: [5120] +171: model.layers.34.mlp.down_proj.weight shape: [5120, 17920] +172: model.layers.34.mlp.gate_up_proj.weight shape: [35840, 5120] +173: model.layers.34.post_attention_layernorm.weight shape: [5120] +174: model.layers.34.self_attn.o_proj.weight shape: [5120, 5120] +175: model.layers.34.self_attn.qkv_proj.weight shape: [7680, 5120] +176: model.layers.35.input_layernorm.weight shape: [5120] +177: model.layers.35.mlp.down_proj.weight shape: [5120, 17920] +178: model.layers.35.mlp.gate_up_proj.weight shape: [35840, 5120] +179: model.layers.35.post_attention_layernorm.weight shape: [5120] +180: model.layers.35.self_attn.o_proj.weight shape: [5120, 5120] +181: model.layers.35.self_attn.qkv_proj.weight shape: [7680, 5120] +182: model.layers.36.input_layernorm.weight shape: [5120] +183: model.layers.36.mlp.down_proj.weight shape: [5120, 17920] +184: model.layers.36.mlp.gate_up_proj.weight shape: [35840, 5120] +185: model.layers.36.post_attention_layernorm.weight shape: [5120] +186: model.layers.36.self_attn.o_proj.weight shape: [5120, 5120] +187: model.layers.36.self_attn.qkv_proj.weight shape: [7680, 5120] +188: model.layers.37.input_layernorm.weight shape: [5120] +189: model.layers.37.mlp.down_proj.weight shape: [5120, 17920] +190: model.layers.37.mlp.gate_up_proj.weight shape: [35840, 5120] +191: model.layers.37.post_attention_layernorm.weight shape: [5120] +192: model.layers.37.self_attn.o_proj.weight shape: [5120, 5120] +193: model.layers.37.self_attn.qkv_proj.weight shape: [7680, 5120] +194: model.layers.38.input_layernorm.weight shape: [5120] +195: model.layers.38.mlp.down_proj.weight shape: [5120, 17920] +196: model.layers.38.mlp.gate_up_proj.weight shape: [35840, 5120] +197: model.layers.38.post_attention_layernorm.weight shape: [5120] +198: model.layers.38.self_attn.o_proj.weight shape: [5120, 5120] +199: model.layers.38.self_attn.qkv_proj.weight shape: [7680, 5120] +200: model.layers.39.input_layernorm.weight shape: [5120] +201: model.layers.39.mlp.down_proj.weight shape: [5120, 17920] +202: model.layers.39.mlp.gate_up_proj.weight shape: [35840, 5120] +203: model.layers.39.post_attention_layernorm.weight shape: [5120] +204: model.layers.39.self_attn.o_proj.weight shape: [5120, 5120] +205: model.layers.39.self_attn.qkv_proj.weight shape: [7680, 5120] +206: model.layers.4.input_layernorm.weight shape: [5120] +207: model.layers.4.mlp.down_proj.weight shape: [5120, 17920] +208: model.layers.4.mlp.gate_up_proj.weight shape: [35840, 5120] +209: model.layers.4.post_attention_layernorm.weight shape: [5120] +210: model.layers.4.self_attn.o_proj.weight shape: [5120, 5120] +211: model.layers.4.self_attn.qkv_proj.weight shape: [7680, 5120] +212: model.layers.5.input_layernorm.weight shape: [5120] +213: model.layers.5.mlp.down_proj.weight shape: [5120, 17920] +214: model.layers.5.mlp.gate_up_proj.weight shape: [35840, 5120] +215: model.layers.5.post_attention_layernorm.weight shape: [5120] +216: model.layers.5.self_attn.o_proj.weight shape: [5120, 5120] +217: model.layers.5.self_attn.qkv_proj.weight shape: [7680, 5120] +218: model.layers.6.input_layernorm.weight shape: [5120] +219: model.layers.6.mlp.down_proj.weight shape: [5120, 17920] +220: model.layers.6.mlp.gate_up_proj.weight shape: [35840, 5120] +221: model.layers.6.post_attention_layernorm.weight shape: [5120] +222: model.layers.6.self_attn.o_proj.weight shape: [5120, 5120] +223: model.layers.6.self_attn.qkv_proj.weight shape: [7680, 5120] +224: model.layers.7.input_layernorm.weight shape: [5120] +225: model.layers.7.mlp.down_proj.weight shape: [5120, 17920] +226: model.layers.7.mlp.gate_up_proj.weight shape: [35840, 5120] +227: model.layers.7.post_attention_layernorm.weight shape: [5120] +228: model.layers.7.self_attn.o_proj.weight shape: [5120, 5120] +229: model.layers.7.self_attn.qkv_proj.weight shape: [7680, 5120] +230: model.layers.8.input_layernorm.weight shape: [5120] +231: model.layers.8.mlp.down_proj.weight shape: [5120, 17920] +232: model.layers.8.mlp.gate_up_proj.weight shape: [35840, 5120] +233: model.layers.8.post_attention_layernorm.weight shape: [5120] +234: model.layers.8.self_attn.o_proj.weight shape: [5120, 5120] +235: model.layers.8.self_attn.qkv_proj.weight shape: [7680, 5120] +236: model.layers.9.input_layernorm.weight shape: [5120] +237: model.layers.9.mlp.down_proj.weight shape: [5120, 17920] +238: model.layers.9.mlp.gate_up_proj.weight shape: [35840, 5120] +239: model.layers.9.post_attention_layernorm.weight shape: [5120] +240: model.layers.9.self_attn.o_proj.weight shape: [5120, 5120] +241: model.layers.9.self_attn.qkv_proj.weight shape: [7680, 5120] +242: model.norm.weight shape: [5120] diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Medium4KShapeTest.approved.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Medium4KShapeTest.approved.txt new file mode 100644 index 0000000000..277f686aa7 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Medium4KShapeTest.approved.txt @@ -0,0 +1,243 @@ +0: lm_head.weight shape: [32064, 5120] +1: model.embed_tokens.weight shape: [32064, 5120] +2: model.layers.0.input_layernorm.weight shape: [5120] +3: model.layers.0.mlp.down_proj.weight shape: [5120, 17920] +4: model.layers.0.mlp.gate_up_proj.weight shape: [35840, 5120] +5: model.layers.0.post_attention_layernorm.weight shape: [5120] +6: model.layers.0.self_attn.o_proj.weight shape: [5120, 5120] +7: model.layers.0.self_attn.qkv_proj.weight shape: [7680, 5120] +8: model.layers.1.input_layernorm.weight shape: [5120] +9: model.layers.1.mlp.down_proj.weight shape: [5120, 17920] +10: model.layers.1.mlp.gate_up_proj.weight shape: [35840, 5120] +11: model.layers.1.post_attention_layernorm.weight shape: [5120] +12: model.layers.1.self_attn.o_proj.weight shape: [5120, 5120] +13: model.layers.1.self_attn.qkv_proj.weight shape: [7680, 5120] +14: model.layers.10.input_layernorm.weight shape: [5120] +15: model.layers.10.mlp.down_proj.weight shape: [5120, 17920] +16: model.layers.10.mlp.gate_up_proj.weight shape: [35840, 5120] +17: model.layers.10.post_attention_layernorm.weight shape: [5120] +18: model.layers.10.self_attn.o_proj.weight shape: [5120, 5120] +19: model.layers.10.self_attn.qkv_proj.weight shape: [7680, 5120] +20: model.layers.11.input_layernorm.weight shape: [5120] +21: model.layers.11.mlp.down_proj.weight shape: [5120, 17920] +22: model.layers.11.mlp.gate_up_proj.weight shape: [35840, 5120] +23: model.layers.11.post_attention_layernorm.weight shape: [5120] +24: model.layers.11.self_attn.o_proj.weight shape: [5120, 5120] +25: model.layers.11.self_attn.qkv_proj.weight shape: [7680, 5120] +26: model.layers.12.input_layernorm.weight shape: [5120] +27: model.layers.12.mlp.down_proj.weight shape: [5120, 17920] +28: model.layers.12.mlp.gate_up_proj.weight shape: [35840, 5120] +29: model.layers.12.post_attention_layernorm.weight shape: [5120] +30: model.layers.12.self_attn.o_proj.weight shape: [5120, 5120] +31: model.layers.12.self_attn.qkv_proj.weight shape: [7680, 5120] +32: model.layers.13.input_layernorm.weight shape: [5120] +33: model.layers.13.mlp.down_proj.weight shape: [5120, 17920] +34: model.layers.13.mlp.gate_up_proj.weight shape: [35840, 5120] +35: model.layers.13.post_attention_layernorm.weight shape: [5120] +36: model.layers.13.self_attn.o_proj.weight shape: [5120, 5120] +37: model.layers.13.self_attn.qkv_proj.weight shape: [7680, 5120] +38: model.layers.14.input_layernorm.weight shape: [5120] +39: model.layers.14.mlp.down_proj.weight shape: [5120, 17920] +40: model.layers.14.mlp.gate_up_proj.weight shape: [35840, 5120] +41: model.layers.14.post_attention_layernorm.weight shape: [5120] +42: model.layers.14.self_attn.o_proj.weight shape: [5120, 5120] +43: model.layers.14.self_attn.qkv_proj.weight shape: [7680, 5120] +44: model.layers.15.input_layernorm.weight shape: [5120] +45: model.layers.15.mlp.down_proj.weight shape: [5120, 17920] +46: model.layers.15.mlp.gate_up_proj.weight shape: [35840, 5120] +47: model.layers.15.post_attention_layernorm.weight shape: [5120] +48: model.layers.15.self_attn.o_proj.weight shape: [5120, 5120] +49: model.layers.15.self_attn.qkv_proj.weight shape: [7680, 5120] +50: model.layers.16.input_layernorm.weight shape: [5120] +51: model.layers.16.mlp.down_proj.weight shape: [5120, 17920] +52: model.layers.16.mlp.gate_up_proj.weight shape: [35840, 5120] +53: model.layers.16.post_attention_layernorm.weight shape: [5120] +54: model.layers.16.self_attn.o_proj.weight shape: [5120, 5120] +55: model.layers.16.self_attn.qkv_proj.weight shape: [7680, 5120] +56: model.layers.17.input_layernorm.weight shape: [5120] +57: model.layers.17.mlp.down_proj.weight shape: [5120, 17920] +58: model.layers.17.mlp.gate_up_proj.weight shape: [35840, 5120] +59: model.layers.17.post_attention_layernorm.weight shape: [5120] +60: model.layers.17.self_attn.o_proj.weight shape: [5120, 5120] +61: model.layers.17.self_attn.qkv_proj.weight shape: [7680, 5120] +62: model.layers.18.input_layernorm.weight shape: [5120] +63: model.layers.18.mlp.down_proj.weight shape: [5120, 17920] +64: model.layers.18.mlp.gate_up_proj.weight shape: [35840, 5120] +65: model.layers.18.post_attention_layernorm.weight shape: [5120] +66: model.layers.18.self_attn.o_proj.weight shape: [5120, 5120] +67: model.layers.18.self_attn.qkv_proj.weight shape: [7680, 5120] +68: model.layers.19.input_layernorm.weight shape: [5120] +69: model.layers.19.mlp.down_proj.weight shape: [5120, 17920] +70: model.layers.19.mlp.gate_up_proj.weight shape: [35840, 5120] +71: model.layers.19.post_attention_layernorm.weight shape: [5120] +72: model.layers.19.self_attn.o_proj.weight shape: [5120, 5120] +73: model.layers.19.self_attn.qkv_proj.weight shape: [7680, 5120] +74: model.layers.2.input_layernorm.weight shape: [5120] +75: model.layers.2.mlp.down_proj.weight shape: [5120, 17920] +76: model.layers.2.mlp.gate_up_proj.weight shape: [35840, 5120] +77: model.layers.2.post_attention_layernorm.weight shape: [5120] +78: model.layers.2.self_attn.o_proj.weight shape: [5120, 5120] +79: model.layers.2.self_attn.qkv_proj.weight shape: [7680, 5120] +80: model.layers.20.input_layernorm.weight shape: [5120] +81: model.layers.20.mlp.down_proj.weight shape: [5120, 17920] +82: model.layers.20.mlp.gate_up_proj.weight shape: [35840, 5120] +83: model.layers.20.post_attention_layernorm.weight shape: [5120] +84: model.layers.20.self_attn.o_proj.weight shape: [5120, 5120] +85: model.layers.20.self_attn.qkv_proj.weight shape: [7680, 5120] +86: model.layers.21.input_layernorm.weight shape: [5120] +87: model.layers.21.mlp.down_proj.weight shape: [5120, 17920] +88: model.layers.21.mlp.gate_up_proj.weight shape: [35840, 5120] +89: model.layers.21.post_attention_layernorm.weight shape: [5120] +90: model.layers.21.self_attn.o_proj.weight shape: [5120, 5120] +91: model.layers.21.self_attn.qkv_proj.weight shape: [7680, 5120] +92: model.layers.22.input_layernorm.weight shape: [5120] +93: model.layers.22.mlp.down_proj.weight shape: [5120, 17920] +94: model.layers.22.mlp.gate_up_proj.weight shape: [35840, 5120] +95: model.layers.22.post_attention_layernorm.weight shape: [5120] +96: model.layers.22.self_attn.o_proj.weight shape: [5120, 5120] +97: model.layers.22.self_attn.qkv_proj.weight shape: [7680, 5120] +98: model.layers.23.input_layernorm.weight shape: [5120] +99: model.layers.23.mlp.down_proj.weight shape: [5120, 17920] +100: model.layers.23.mlp.gate_up_proj.weight shape: [35840, 5120] +101: model.layers.23.post_attention_layernorm.weight shape: [5120] +102: model.layers.23.self_attn.o_proj.weight shape: [5120, 5120] +103: model.layers.23.self_attn.qkv_proj.weight shape: [7680, 5120] +104: model.layers.24.input_layernorm.weight shape: [5120] +105: model.layers.24.mlp.down_proj.weight shape: [5120, 17920] +106: model.layers.24.mlp.gate_up_proj.weight shape: [35840, 5120] +107: model.layers.24.post_attention_layernorm.weight shape: [5120] +108: model.layers.24.self_attn.o_proj.weight shape: [5120, 5120] +109: model.layers.24.self_attn.qkv_proj.weight shape: [7680, 5120] +110: model.layers.25.input_layernorm.weight shape: [5120] +111: model.layers.25.mlp.down_proj.weight shape: [5120, 17920] +112: model.layers.25.mlp.gate_up_proj.weight shape: [35840, 5120] +113: model.layers.25.post_attention_layernorm.weight shape: [5120] +114: model.layers.25.self_attn.o_proj.weight shape: [5120, 5120] +115: model.layers.25.self_attn.qkv_proj.weight shape: [7680, 5120] +116: model.layers.26.input_layernorm.weight shape: [5120] +117: model.layers.26.mlp.down_proj.weight shape: [5120, 17920] +118: model.layers.26.mlp.gate_up_proj.weight shape: [35840, 5120] +119: model.layers.26.post_attention_layernorm.weight shape: [5120] +120: model.layers.26.self_attn.o_proj.weight shape: [5120, 5120] +121: model.layers.26.self_attn.qkv_proj.weight shape: [7680, 5120] +122: model.layers.27.input_layernorm.weight shape: [5120] +123: model.layers.27.mlp.down_proj.weight shape: [5120, 17920] +124: model.layers.27.mlp.gate_up_proj.weight shape: [35840, 5120] +125: model.layers.27.post_attention_layernorm.weight shape: [5120] +126: model.layers.27.self_attn.o_proj.weight shape: [5120, 5120] +127: model.layers.27.self_attn.qkv_proj.weight shape: [7680, 5120] +128: model.layers.28.input_layernorm.weight shape: [5120] +129: model.layers.28.mlp.down_proj.weight shape: [5120, 17920] +130: model.layers.28.mlp.gate_up_proj.weight shape: [35840, 5120] +131: model.layers.28.post_attention_layernorm.weight shape: [5120] +132: model.layers.28.self_attn.o_proj.weight shape: [5120, 5120] +133: model.layers.28.self_attn.qkv_proj.weight shape: [7680, 5120] +134: model.layers.29.input_layernorm.weight shape: [5120] +135: model.layers.29.mlp.down_proj.weight shape: [5120, 17920] +136: model.layers.29.mlp.gate_up_proj.weight shape: [35840, 5120] +137: model.layers.29.post_attention_layernorm.weight shape: [5120] +138: model.layers.29.self_attn.o_proj.weight shape: [5120, 5120] +139: model.layers.29.self_attn.qkv_proj.weight shape: [7680, 5120] +140: model.layers.3.input_layernorm.weight shape: [5120] +141: model.layers.3.mlp.down_proj.weight shape: [5120, 17920] +142: model.layers.3.mlp.gate_up_proj.weight shape: [35840, 5120] +143: model.layers.3.post_attention_layernorm.weight shape: [5120] +144: model.layers.3.self_attn.o_proj.weight shape: [5120, 5120] +145: model.layers.3.self_attn.qkv_proj.weight shape: [7680, 5120] +146: model.layers.30.input_layernorm.weight shape: [5120] +147: model.layers.30.mlp.down_proj.weight shape: [5120, 17920] +148: model.layers.30.mlp.gate_up_proj.weight shape: [35840, 5120] +149: model.layers.30.post_attention_layernorm.weight shape: [5120] +150: model.layers.30.self_attn.o_proj.weight shape: [5120, 5120] +151: model.layers.30.self_attn.qkv_proj.weight shape: [7680, 5120] +152: model.layers.31.input_layernorm.weight shape: [5120] +153: model.layers.31.mlp.down_proj.weight shape: [5120, 17920] +154: model.layers.31.mlp.gate_up_proj.weight shape: [35840, 5120] +155: model.layers.31.post_attention_layernorm.weight shape: [5120] +156: model.layers.31.self_attn.o_proj.weight shape: [5120, 5120] +157: model.layers.31.self_attn.qkv_proj.weight shape: [7680, 5120] +158: model.layers.32.input_layernorm.weight shape: [5120] +159: model.layers.32.mlp.down_proj.weight shape: [5120, 17920] +160: model.layers.32.mlp.gate_up_proj.weight shape: [35840, 5120] +161: model.layers.32.post_attention_layernorm.weight shape: [5120] +162: model.layers.32.self_attn.o_proj.weight shape: [5120, 5120] +163: model.layers.32.self_attn.qkv_proj.weight shape: [7680, 5120] +164: model.layers.33.input_layernorm.weight shape: [5120] +165: model.layers.33.mlp.down_proj.weight shape: [5120, 17920] +166: model.layers.33.mlp.gate_up_proj.weight shape: [35840, 5120] +167: model.layers.33.post_attention_layernorm.weight shape: [5120] +168: model.layers.33.self_attn.o_proj.weight shape: [5120, 5120] +169: model.layers.33.self_attn.qkv_proj.weight shape: [7680, 5120] +170: model.layers.34.input_layernorm.weight shape: [5120] +171: model.layers.34.mlp.down_proj.weight shape: [5120, 17920] +172: model.layers.34.mlp.gate_up_proj.weight shape: [35840, 5120] +173: model.layers.34.post_attention_layernorm.weight shape: [5120] +174: model.layers.34.self_attn.o_proj.weight shape: [5120, 5120] +175: model.layers.34.self_attn.qkv_proj.weight shape: [7680, 5120] +176: model.layers.35.input_layernorm.weight shape: [5120] +177: model.layers.35.mlp.down_proj.weight shape: [5120, 17920] +178: model.layers.35.mlp.gate_up_proj.weight shape: [35840, 5120] +179: model.layers.35.post_attention_layernorm.weight shape: [5120] +180: model.layers.35.self_attn.o_proj.weight shape: [5120, 5120] +181: model.layers.35.self_attn.qkv_proj.weight shape: [7680, 5120] +182: model.layers.36.input_layernorm.weight shape: [5120] +183: model.layers.36.mlp.down_proj.weight shape: [5120, 17920] +184: model.layers.36.mlp.gate_up_proj.weight shape: [35840, 5120] +185: model.layers.36.post_attention_layernorm.weight shape: [5120] +186: model.layers.36.self_attn.o_proj.weight shape: [5120, 5120] +187: model.layers.36.self_attn.qkv_proj.weight shape: [7680, 5120] +188: model.layers.37.input_layernorm.weight shape: [5120] +189: model.layers.37.mlp.down_proj.weight shape: [5120, 17920] +190: model.layers.37.mlp.gate_up_proj.weight shape: [35840, 5120] +191: model.layers.37.post_attention_layernorm.weight shape: [5120] +192: model.layers.37.self_attn.o_proj.weight shape: [5120, 5120] +193: model.layers.37.self_attn.qkv_proj.weight shape: [7680, 5120] +194: model.layers.38.input_layernorm.weight shape: [5120] +195: model.layers.38.mlp.down_proj.weight shape: [5120, 17920] +196: model.layers.38.mlp.gate_up_proj.weight shape: [35840, 5120] +197: model.layers.38.post_attention_layernorm.weight shape: [5120] +198: model.layers.38.self_attn.o_proj.weight shape: [5120, 5120] +199: model.layers.38.self_attn.qkv_proj.weight shape: [7680, 5120] +200: model.layers.39.input_layernorm.weight shape: [5120] +201: model.layers.39.mlp.down_proj.weight shape: [5120, 17920] +202: model.layers.39.mlp.gate_up_proj.weight shape: [35840, 5120] +203: model.layers.39.post_attention_layernorm.weight shape: [5120] +204: model.layers.39.self_attn.o_proj.weight shape: [5120, 5120] +205: model.layers.39.self_attn.qkv_proj.weight shape: [7680, 5120] +206: model.layers.4.input_layernorm.weight shape: [5120] +207: model.layers.4.mlp.down_proj.weight shape: [5120, 17920] +208: model.layers.4.mlp.gate_up_proj.weight shape: [35840, 5120] +209: model.layers.4.post_attention_layernorm.weight shape: [5120] +210: model.layers.4.self_attn.o_proj.weight shape: [5120, 5120] +211: model.layers.4.self_attn.qkv_proj.weight shape: [7680, 5120] +212: model.layers.5.input_layernorm.weight shape: [5120] +213: model.layers.5.mlp.down_proj.weight shape: [5120, 17920] +214: model.layers.5.mlp.gate_up_proj.weight shape: [35840, 5120] +215: model.layers.5.post_attention_layernorm.weight shape: [5120] +216: model.layers.5.self_attn.o_proj.weight shape: [5120, 5120] +217: model.layers.5.self_attn.qkv_proj.weight shape: [7680, 5120] +218: model.layers.6.input_layernorm.weight shape: [5120] +219: model.layers.6.mlp.down_proj.weight shape: [5120, 17920] +220: model.layers.6.mlp.gate_up_proj.weight shape: [35840, 5120] +221: model.layers.6.post_attention_layernorm.weight shape: [5120] +222: model.layers.6.self_attn.o_proj.weight shape: [5120, 5120] +223: model.layers.6.self_attn.qkv_proj.weight shape: [7680, 5120] +224: model.layers.7.input_layernorm.weight shape: [5120] +225: model.layers.7.mlp.down_proj.weight shape: [5120, 17920] +226: model.layers.7.mlp.gate_up_proj.weight shape: [35840, 5120] +227: model.layers.7.post_attention_layernorm.weight shape: [5120] +228: model.layers.7.self_attn.o_proj.weight shape: [5120, 5120] +229: model.layers.7.self_attn.qkv_proj.weight shape: [7680, 5120] +230: model.layers.8.input_layernorm.weight shape: [5120] +231: model.layers.8.mlp.down_proj.weight shape: [5120, 17920] +232: model.layers.8.mlp.gate_up_proj.weight shape: [35840, 5120] +233: model.layers.8.post_attention_layernorm.weight shape: [5120] +234: model.layers.8.self_attn.o_proj.weight shape: [5120, 5120] +235: model.layers.8.self_attn.qkv_proj.weight shape: [7680, 5120] +236: model.layers.9.input_layernorm.weight shape: [5120] +237: model.layers.9.mlp.down_proj.weight shape: [5120, 17920] +238: model.layers.9.mlp.gate_up_proj.weight shape: [35840, 5120] +239: model.layers.9.post_attention_layernorm.weight shape: [5120] +240: model.layers.9.self_attn.o_proj.weight shape: [5120, 5120] +241: model.layers.9.self_attn.qkv_proj.weight shape: [7680, 5120] +242: model.norm.weight shape: [5120] diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini128KLayerSizeTest.approved.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini128KLayerSizeTest.approved.txt new file mode 100644 index 0000000000..edb1e258bb --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini128KLayerSizeTest.approved.txt @@ -0,0 +1,34 @@ +{ + "model.layers.0": 216.01172, + "model.layers.1": 216.01172, + "model.layers.2": 216.01172, + "model.layers.3": 216.01172, + "model.layers.4": 216.01172, + "model.layers.5": 216.01172, + "model.layers.6": 216.01172, + "model.layers.7": 216.01172, + "model.layers.8": 216.01172, + "model.layers.9": 216.01172, + "model.layers.10": 216.01172, + "model.layers.11": 216.01172, + "model.layers.12": 216.01172, + "model.layers.13": 216.01172, + "model.layers.14": 216.01172, + "model.layers.15": 216.01172, + "model.layers.16": 216.01172, + "model.layers.17": 216.01172, + "model.layers.18": 216.01172, + "model.layers.19": 216.01172, + "model.layers.20": 216.01172, + "model.layers.21": 216.01172, + "model.layers.22": 216.01172, + "model.layers.23": 216.01172, + "model.layers.24": 216.01172, + "model.layers.25": 216.01172, + "model.layers.26": 216.01172, + "model.layers.27": 216.01172, + "model.layers.28": 216.01172, + "model.layers.29": 216.01172, + "model.layers.30": 216.01172, + "model.layers.31": 216.01172 +} \ No newline at end of file diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini128KShapeTest.approved.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini128KShapeTest.approved.txt new file mode 100644 index 0000000000..2278f3b67d --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini128KShapeTest.approved.txt @@ -0,0 +1,195 @@ +0: lm_head.weight shape: [32064, 3072] +1: model.embed_tokens.weight shape: [32064, 3072] +2: model.layers.0.input_layernorm.weight shape: [3072] +3: model.layers.0.mlp.down_proj.weight shape: [3072, 8192] +4: model.layers.0.mlp.gate_up_proj.weight shape: [16384, 3072] +5: model.layers.0.post_attention_layernorm.weight shape: [3072] +6: model.layers.0.self_attn.o_proj.weight shape: [3072, 3072] +7: model.layers.0.self_attn.qkv_proj.weight shape: [9216, 3072] +8: model.layers.1.input_layernorm.weight shape: [3072] +9: model.layers.1.mlp.down_proj.weight shape: [3072, 8192] +10: model.layers.1.mlp.gate_up_proj.weight shape: [16384, 3072] +11: model.layers.1.post_attention_layernorm.weight shape: [3072] +12: model.layers.1.self_attn.o_proj.weight shape: [3072, 3072] +13: model.layers.1.self_attn.qkv_proj.weight shape: [9216, 3072] +14: model.layers.10.input_layernorm.weight shape: [3072] +15: model.layers.10.mlp.down_proj.weight shape: [3072, 8192] +16: model.layers.10.mlp.gate_up_proj.weight shape: [16384, 3072] +17: model.layers.10.post_attention_layernorm.weight shape: [3072] +18: model.layers.10.self_attn.o_proj.weight shape: [3072, 3072] +19: model.layers.10.self_attn.qkv_proj.weight shape: [9216, 3072] +20: model.layers.11.input_layernorm.weight shape: [3072] +21: model.layers.11.mlp.down_proj.weight shape: [3072, 8192] +22: model.layers.11.mlp.gate_up_proj.weight shape: [16384, 3072] +23: model.layers.11.post_attention_layernorm.weight shape: [3072] +24: model.layers.11.self_attn.o_proj.weight shape: [3072, 3072] +25: model.layers.11.self_attn.qkv_proj.weight shape: [9216, 3072] +26: model.layers.12.input_layernorm.weight shape: [3072] +27: model.layers.12.mlp.down_proj.weight shape: [3072, 8192] +28: model.layers.12.mlp.gate_up_proj.weight shape: [16384, 3072] +29: model.layers.12.post_attention_layernorm.weight shape: [3072] +30: model.layers.12.self_attn.o_proj.weight shape: [3072, 3072] +31: model.layers.12.self_attn.qkv_proj.weight shape: [9216, 3072] +32: model.layers.13.input_layernorm.weight shape: [3072] +33: model.layers.13.mlp.down_proj.weight shape: [3072, 8192] +34: model.layers.13.mlp.gate_up_proj.weight shape: [16384, 3072] +35: model.layers.13.post_attention_layernorm.weight shape: [3072] +36: model.layers.13.self_attn.o_proj.weight shape: [3072, 3072] +37: model.layers.13.self_attn.qkv_proj.weight shape: [9216, 3072] +38: model.layers.14.input_layernorm.weight shape: [3072] +39: model.layers.14.mlp.down_proj.weight shape: [3072, 8192] +40: model.layers.14.mlp.gate_up_proj.weight shape: [16384, 3072] +41: model.layers.14.post_attention_layernorm.weight shape: [3072] +42: model.layers.14.self_attn.o_proj.weight shape: [3072, 3072] +43: model.layers.14.self_attn.qkv_proj.weight shape: [9216, 3072] +44: model.layers.15.input_layernorm.weight shape: [3072] +45: model.layers.15.mlp.down_proj.weight shape: [3072, 8192] +46: model.layers.15.mlp.gate_up_proj.weight shape: [16384, 3072] +47: model.layers.15.post_attention_layernorm.weight shape: [3072] +48: model.layers.15.self_attn.o_proj.weight shape: [3072, 3072] +49: model.layers.15.self_attn.qkv_proj.weight shape: [9216, 3072] +50: model.layers.16.input_layernorm.weight shape: [3072] +51: model.layers.16.mlp.down_proj.weight shape: [3072, 8192] +52: model.layers.16.mlp.gate_up_proj.weight shape: [16384, 3072] +53: model.layers.16.post_attention_layernorm.weight shape: [3072] +54: model.layers.16.self_attn.o_proj.weight shape: [3072, 3072] +55: model.layers.16.self_attn.qkv_proj.weight shape: [9216, 3072] +56: model.layers.17.input_layernorm.weight shape: [3072] +57: model.layers.17.mlp.down_proj.weight shape: [3072, 8192] +58: model.layers.17.mlp.gate_up_proj.weight shape: [16384, 3072] +59: model.layers.17.post_attention_layernorm.weight shape: [3072] +60: model.layers.17.self_attn.o_proj.weight shape: [3072, 3072] +61: model.layers.17.self_attn.qkv_proj.weight shape: [9216, 3072] +62: model.layers.18.input_layernorm.weight shape: [3072] +63: model.layers.18.mlp.down_proj.weight shape: [3072, 8192] +64: model.layers.18.mlp.gate_up_proj.weight shape: [16384, 3072] +65: model.layers.18.post_attention_layernorm.weight shape: [3072] +66: model.layers.18.self_attn.o_proj.weight shape: [3072, 3072] +67: model.layers.18.self_attn.qkv_proj.weight shape: [9216, 3072] +68: model.layers.19.input_layernorm.weight shape: [3072] +69: model.layers.19.mlp.down_proj.weight shape: [3072, 8192] +70: model.layers.19.mlp.gate_up_proj.weight shape: [16384, 3072] +71: model.layers.19.post_attention_layernorm.weight shape: [3072] +72: model.layers.19.self_attn.o_proj.weight shape: [3072, 3072] +73: model.layers.19.self_attn.qkv_proj.weight shape: [9216, 3072] +74: model.layers.2.input_layernorm.weight shape: [3072] +75: model.layers.2.mlp.down_proj.weight shape: [3072, 8192] +76: model.layers.2.mlp.gate_up_proj.weight shape: [16384, 3072] +77: model.layers.2.post_attention_layernorm.weight shape: [3072] +78: model.layers.2.self_attn.o_proj.weight shape: [3072, 3072] +79: model.layers.2.self_attn.qkv_proj.weight shape: [9216, 3072] +80: model.layers.20.input_layernorm.weight shape: [3072] +81: model.layers.20.mlp.down_proj.weight shape: [3072, 8192] +82: model.layers.20.mlp.gate_up_proj.weight shape: [16384, 3072] +83: model.layers.20.post_attention_layernorm.weight shape: [3072] +84: model.layers.20.self_attn.o_proj.weight shape: [3072, 3072] +85: model.layers.20.self_attn.qkv_proj.weight shape: [9216, 3072] +86: model.layers.21.input_layernorm.weight shape: [3072] +87: model.layers.21.mlp.down_proj.weight shape: [3072, 8192] +88: model.layers.21.mlp.gate_up_proj.weight shape: [16384, 3072] +89: model.layers.21.post_attention_layernorm.weight shape: [3072] +90: model.layers.21.self_attn.o_proj.weight shape: [3072, 3072] +91: model.layers.21.self_attn.qkv_proj.weight shape: [9216, 3072] +92: model.layers.22.input_layernorm.weight shape: [3072] +93: model.layers.22.mlp.down_proj.weight shape: [3072, 8192] +94: model.layers.22.mlp.gate_up_proj.weight shape: [16384, 3072] +95: model.layers.22.post_attention_layernorm.weight shape: [3072] +96: model.layers.22.self_attn.o_proj.weight shape: [3072, 3072] +97: model.layers.22.self_attn.qkv_proj.weight shape: [9216, 3072] +98: model.layers.23.input_layernorm.weight shape: [3072] +99: model.layers.23.mlp.down_proj.weight shape: [3072, 8192] +100: model.layers.23.mlp.gate_up_proj.weight shape: [16384, 3072] +101: model.layers.23.post_attention_layernorm.weight shape: [3072] +102: model.layers.23.self_attn.o_proj.weight shape: [3072, 3072] +103: model.layers.23.self_attn.qkv_proj.weight shape: [9216, 3072] +104: model.layers.24.input_layernorm.weight shape: [3072] +105: model.layers.24.mlp.down_proj.weight shape: [3072, 8192] +106: model.layers.24.mlp.gate_up_proj.weight shape: [16384, 3072] +107: model.layers.24.post_attention_layernorm.weight shape: [3072] +108: model.layers.24.self_attn.o_proj.weight shape: [3072, 3072] +109: model.layers.24.self_attn.qkv_proj.weight shape: [9216, 3072] +110: model.layers.25.input_layernorm.weight shape: [3072] +111: model.layers.25.mlp.down_proj.weight shape: [3072, 8192] +112: model.layers.25.mlp.gate_up_proj.weight shape: [16384, 3072] +113: model.layers.25.post_attention_layernorm.weight shape: [3072] +114: model.layers.25.self_attn.o_proj.weight shape: [3072, 3072] +115: model.layers.25.self_attn.qkv_proj.weight shape: [9216, 3072] +116: model.layers.26.input_layernorm.weight shape: [3072] +117: model.layers.26.mlp.down_proj.weight shape: [3072, 8192] +118: model.layers.26.mlp.gate_up_proj.weight shape: [16384, 3072] +119: model.layers.26.post_attention_layernorm.weight shape: [3072] +120: model.layers.26.self_attn.o_proj.weight shape: [3072, 3072] +121: model.layers.26.self_attn.qkv_proj.weight shape: [9216, 3072] +122: model.layers.27.input_layernorm.weight shape: [3072] +123: model.layers.27.mlp.down_proj.weight shape: [3072, 8192] +124: model.layers.27.mlp.gate_up_proj.weight shape: [16384, 3072] +125: model.layers.27.post_attention_layernorm.weight shape: [3072] +126: model.layers.27.self_attn.o_proj.weight shape: [3072, 3072] +127: model.layers.27.self_attn.qkv_proj.weight shape: [9216, 3072] +128: model.layers.28.input_layernorm.weight shape: [3072] +129: model.layers.28.mlp.down_proj.weight shape: [3072, 8192] +130: model.layers.28.mlp.gate_up_proj.weight shape: [16384, 3072] +131: model.layers.28.post_attention_layernorm.weight shape: [3072] +132: model.layers.28.self_attn.o_proj.weight shape: [3072, 3072] +133: model.layers.28.self_attn.qkv_proj.weight shape: [9216, 3072] +134: model.layers.29.input_layernorm.weight shape: [3072] +135: model.layers.29.mlp.down_proj.weight shape: [3072, 8192] +136: model.layers.29.mlp.gate_up_proj.weight shape: [16384, 3072] +137: model.layers.29.post_attention_layernorm.weight shape: [3072] +138: model.layers.29.self_attn.o_proj.weight shape: [3072, 3072] +139: model.layers.29.self_attn.qkv_proj.weight shape: [9216, 3072] +140: model.layers.3.input_layernorm.weight shape: [3072] +141: model.layers.3.mlp.down_proj.weight shape: [3072, 8192] +142: model.layers.3.mlp.gate_up_proj.weight shape: [16384, 3072] +143: model.layers.3.post_attention_layernorm.weight shape: [3072] +144: model.layers.3.self_attn.o_proj.weight shape: [3072, 3072] +145: model.layers.3.self_attn.qkv_proj.weight shape: [9216, 3072] +146: model.layers.30.input_layernorm.weight shape: [3072] +147: model.layers.30.mlp.down_proj.weight shape: [3072, 8192] +148: model.layers.30.mlp.gate_up_proj.weight shape: [16384, 3072] +149: model.layers.30.post_attention_layernorm.weight shape: [3072] +150: model.layers.30.self_attn.o_proj.weight shape: [3072, 3072] +151: model.layers.30.self_attn.qkv_proj.weight shape: [9216, 3072] +152: model.layers.31.input_layernorm.weight shape: [3072] +153: model.layers.31.mlp.down_proj.weight shape: [3072, 8192] +154: model.layers.31.mlp.gate_up_proj.weight shape: [16384, 3072] +155: model.layers.31.post_attention_layernorm.weight shape: [3072] +156: model.layers.31.self_attn.o_proj.weight shape: [3072, 3072] +157: model.layers.31.self_attn.qkv_proj.weight shape: [9216, 3072] +158: model.layers.4.input_layernorm.weight shape: [3072] +159: model.layers.4.mlp.down_proj.weight shape: [3072, 8192] +160: model.layers.4.mlp.gate_up_proj.weight shape: [16384, 3072] +161: model.layers.4.post_attention_layernorm.weight shape: [3072] +162: model.layers.4.self_attn.o_proj.weight shape: [3072, 3072] +163: model.layers.4.self_attn.qkv_proj.weight shape: [9216, 3072] +164: model.layers.5.input_layernorm.weight shape: [3072] +165: model.layers.5.mlp.down_proj.weight shape: [3072, 8192] +166: model.layers.5.mlp.gate_up_proj.weight shape: [16384, 3072] +167: model.layers.5.post_attention_layernorm.weight shape: [3072] +168: model.layers.5.self_attn.o_proj.weight shape: [3072, 3072] +169: model.layers.5.self_attn.qkv_proj.weight shape: [9216, 3072] +170: model.layers.6.input_layernorm.weight shape: [3072] +171: model.layers.6.mlp.down_proj.weight shape: [3072, 8192] +172: model.layers.6.mlp.gate_up_proj.weight shape: [16384, 3072] +173: model.layers.6.post_attention_layernorm.weight shape: [3072] +174: model.layers.6.self_attn.o_proj.weight shape: [3072, 3072] +175: model.layers.6.self_attn.qkv_proj.weight shape: [9216, 3072] +176: model.layers.7.input_layernorm.weight shape: [3072] +177: model.layers.7.mlp.down_proj.weight shape: [3072, 8192] +178: model.layers.7.mlp.gate_up_proj.weight shape: [16384, 3072] +179: model.layers.7.post_attention_layernorm.weight shape: [3072] +180: model.layers.7.self_attn.o_proj.weight shape: [3072, 3072] +181: model.layers.7.self_attn.qkv_proj.weight shape: [9216, 3072] +182: model.layers.8.input_layernorm.weight shape: [3072] +183: model.layers.8.mlp.down_proj.weight shape: [3072, 8192] +184: model.layers.8.mlp.gate_up_proj.weight shape: [16384, 3072] +185: model.layers.8.post_attention_layernorm.weight shape: [3072] +186: model.layers.8.self_attn.o_proj.weight shape: [3072, 3072] +187: model.layers.8.self_attn.qkv_proj.weight shape: [9216, 3072] +188: model.layers.9.input_layernorm.weight shape: [3072] +189: model.layers.9.mlp.down_proj.weight shape: [3072, 8192] +190: model.layers.9.mlp.gate_up_proj.weight shape: [16384, 3072] +191: model.layers.9.post_attention_layernorm.weight shape: [3072] +192: model.layers.9.self_attn.o_proj.weight shape: [3072, 3072] +193: model.layers.9.self_attn.qkv_proj.weight shape: [9216, 3072] +194: model.norm.weight shape: [3072] diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KShapeTest.approved.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KShapeTest.approved.txt new file mode 100644 index 0000000000..2278f3b67d --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KShapeTest.approved.txt @@ -0,0 +1,195 @@ +0: lm_head.weight shape: [32064, 3072] +1: model.embed_tokens.weight shape: [32064, 3072] +2: model.layers.0.input_layernorm.weight shape: [3072] +3: model.layers.0.mlp.down_proj.weight shape: [3072, 8192] +4: model.layers.0.mlp.gate_up_proj.weight shape: [16384, 3072] +5: model.layers.0.post_attention_layernorm.weight shape: [3072] +6: model.layers.0.self_attn.o_proj.weight shape: [3072, 3072] +7: model.layers.0.self_attn.qkv_proj.weight shape: [9216, 3072] +8: model.layers.1.input_layernorm.weight shape: [3072] +9: model.layers.1.mlp.down_proj.weight shape: [3072, 8192] +10: model.layers.1.mlp.gate_up_proj.weight shape: [16384, 3072] +11: model.layers.1.post_attention_layernorm.weight shape: [3072] +12: model.layers.1.self_attn.o_proj.weight shape: [3072, 3072] +13: model.layers.1.self_attn.qkv_proj.weight shape: [9216, 3072] +14: model.layers.10.input_layernorm.weight shape: [3072] +15: model.layers.10.mlp.down_proj.weight shape: [3072, 8192] +16: model.layers.10.mlp.gate_up_proj.weight shape: [16384, 3072] +17: model.layers.10.post_attention_layernorm.weight shape: [3072] +18: model.layers.10.self_attn.o_proj.weight shape: [3072, 3072] +19: model.layers.10.self_attn.qkv_proj.weight shape: [9216, 3072] +20: model.layers.11.input_layernorm.weight shape: [3072] +21: model.layers.11.mlp.down_proj.weight shape: [3072, 8192] +22: model.layers.11.mlp.gate_up_proj.weight shape: [16384, 3072] +23: model.layers.11.post_attention_layernorm.weight shape: [3072] +24: model.layers.11.self_attn.o_proj.weight shape: [3072, 3072] +25: model.layers.11.self_attn.qkv_proj.weight shape: [9216, 3072] +26: model.layers.12.input_layernorm.weight shape: [3072] +27: model.layers.12.mlp.down_proj.weight shape: [3072, 8192] +28: model.layers.12.mlp.gate_up_proj.weight shape: [16384, 3072] +29: model.layers.12.post_attention_layernorm.weight shape: [3072] +30: model.layers.12.self_attn.o_proj.weight shape: [3072, 3072] +31: model.layers.12.self_attn.qkv_proj.weight shape: [9216, 3072] +32: model.layers.13.input_layernorm.weight shape: [3072] +33: model.layers.13.mlp.down_proj.weight shape: [3072, 8192] +34: model.layers.13.mlp.gate_up_proj.weight shape: [16384, 3072] +35: model.layers.13.post_attention_layernorm.weight shape: [3072] +36: model.layers.13.self_attn.o_proj.weight shape: [3072, 3072] +37: model.layers.13.self_attn.qkv_proj.weight shape: [9216, 3072] +38: model.layers.14.input_layernorm.weight shape: [3072] +39: model.layers.14.mlp.down_proj.weight shape: [3072, 8192] +40: model.layers.14.mlp.gate_up_proj.weight shape: [16384, 3072] +41: model.layers.14.post_attention_layernorm.weight shape: [3072] +42: model.layers.14.self_attn.o_proj.weight shape: [3072, 3072] +43: model.layers.14.self_attn.qkv_proj.weight shape: [9216, 3072] +44: model.layers.15.input_layernorm.weight shape: [3072] +45: model.layers.15.mlp.down_proj.weight shape: [3072, 8192] +46: model.layers.15.mlp.gate_up_proj.weight shape: [16384, 3072] +47: model.layers.15.post_attention_layernorm.weight shape: [3072] +48: model.layers.15.self_attn.o_proj.weight shape: [3072, 3072] +49: model.layers.15.self_attn.qkv_proj.weight shape: [9216, 3072] +50: model.layers.16.input_layernorm.weight shape: [3072] +51: model.layers.16.mlp.down_proj.weight shape: [3072, 8192] +52: model.layers.16.mlp.gate_up_proj.weight shape: [16384, 3072] +53: model.layers.16.post_attention_layernorm.weight shape: [3072] +54: model.layers.16.self_attn.o_proj.weight shape: [3072, 3072] +55: model.layers.16.self_attn.qkv_proj.weight shape: [9216, 3072] +56: model.layers.17.input_layernorm.weight shape: [3072] +57: model.layers.17.mlp.down_proj.weight shape: [3072, 8192] +58: model.layers.17.mlp.gate_up_proj.weight shape: [16384, 3072] +59: model.layers.17.post_attention_layernorm.weight shape: [3072] +60: model.layers.17.self_attn.o_proj.weight shape: [3072, 3072] +61: model.layers.17.self_attn.qkv_proj.weight shape: [9216, 3072] +62: model.layers.18.input_layernorm.weight shape: [3072] +63: model.layers.18.mlp.down_proj.weight shape: [3072, 8192] +64: model.layers.18.mlp.gate_up_proj.weight shape: [16384, 3072] +65: model.layers.18.post_attention_layernorm.weight shape: [3072] +66: model.layers.18.self_attn.o_proj.weight shape: [3072, 3072] +67: model.layers.18.self_attn.qkv_proj.weight shape: [9216, 3072] +68: model.layers.19.input_layernorm.weight shape: [3072] +69: model.layers.19.mlp.down_proj.weight shape: [3072, 8192] +70: model.layers.19.mlp.gate_up_proj.weight shape: [16384, 3072] +71: model.layers.19.post_attention_layernorm.weight shape: [3072] +72: model.layers.19.self_attn.o_proj.weight shape: [3072, 3072] +73: model.layers.19.self_attn.qkv_proj.weight shape: [9216, 3072] +74: model.layers.2.input_layernorm.weight shape: [3072] +75: model.layers.2.mlp.down_proj.weight shape: [3072, 8192] +76: model.layers.2.mlp.gate_up_proj.weight shape: [16384, 3072] +77: model.layers.2.post_attention_layernorm.weight shape: [3072] +78: model.layers.2.self_attn.o_proj.weight shape: [3072, 3072] +79: model.layers.2.self_attn.qkv_proj.weight shape: [9216, 3072] +80: model.layers.20.input_layernorm.weight shape: [3072] +81: model.layers.20.mlp.down_proj.weight shape: [3072, 8192] +82: model.layers.20.mlp.gate_up_proj.weight shape: [16384, 3072] +83: model.layers.20.post_attention_layernorm.weight shape: [3072] +84: model.layers.20.self_attn.o_proj.weight shape: [3072, 3072] +85: model.layers.20.self_attn.qkv_proj.weight shape: [9216, 3072] +86: model.layers.21.input_layernorm.weight shape: [3072] +87: model.layers.21.mlp.down_proj.weight shape: [3072, 8192] +88: model.layers.21.mlp.gate_up_proj.weight shape: [16384, 3072] +89: model.layers.21.post_attention_layernorm.weight shape: [3072] +90: model.layers.21.self_attn.o_proj.weight shape: [3072, 3072] +91: model.layers.21.self_attn.qkv_proj.weight shape: [9216, 3072] +92: model.layers.22.input_layernorm.weight shape: [3072] +93: model.layers.22.mlp.down_proj.weight shape: [3072, 8192] +94: model.layers.22.mlp.gate_up_proj.weight shape: [16384, 3072] +95: model.layers.22.post_attention_layernorm.weight shape: [3072] +96: model.layers.22.self_attn.o_proj.weight shape: [3072, 3072] +97: model.layers.22.self_attn.qkv_proj.weight shape: [9216, 3072] +98: model.layers.23.input_layernorm.weight shape: [3072] +99: model.layers.23.mlp.down_proj.weight shape: [3072, 8192] +100: model.layers.23.mlp.gate_up_proj.weight shape: [16384, 3072] +101: model.layers.23.post_attention_layernorm.weight shape: [3072] +102: model.layers.23.self_attn.o_proj.weight shape: [3072, 3072] +103: model.layers.23.self_attn.qkv_proj.weight shape: [9216, 3072] +104: model.layers.24.input_layernorm.weight shape: [3072] +105: model.layers.24.mlp.down_proj.weight shape: [3072, 8192] +106: model.layers.24.mlp.gate_up_proj.weight shape: [16384, 3072] +107: model.layers.24.post_attention_layernorm.weight shape: [3072] +108: model.layers.24.self_attn.o_proj.weight shape: [3072, 3072] +109: model.layers.24.self_attn.qkv_proj.weight shape: [9216, 3072] +110: model.layers.25.input_layernorm.weight shape: [3072] +111: model.layers.25.mlp.down_proj.weight shape: [3072, 8192] +112: model.layers.25.mlp.gate_up_proj.weight shape: [16384, 3072] +113: model.layers.25.post_attention_layernorm.weight shape: [3072] +114: model.layers.25.self_attn.o_proj.weight shape: [3072, 3072] +115: model.layers.25.self_attn.qkv_proj.weight shape: [9216, 3072] +116: model.layers.26.input_layernorm.weight shape: [3072] +117: model.layers.26.mlp.down_proj.weight shape: [3072, 8192] +118: model.layers.26.mlp.gate_up_proj.weight shape: [16384, 3072] +119: model.layers.26.post_attention_layernorm.weight shape: [3072] +120: model.layers.26.self_attn.o_proj.weight shape: [3072, 3072] +121: model.layers.26.self_attn.qkv_proj.weight shape: [9216, 3072] +122: model.layers.27.input_layernorm.weight shape: [3072] +123: model.layers.27.mlp.down_proj.weight shape: [3072, 8192] +124: model.layers.27.mlp.gate_up_proj.weight shape: [16384, 3072] +125: model.layers.27.post_attention_layernorm.weight shape: [3072] +126: model.layers.27.self_attn.o_proj.weight shape: [3072, 3072] +127: model.layers.27.self_attn.qkv_proj.weight shape: [9216, 3072] +128: model.layers.28.input_layernorm.weight shape: [3072] +129: model.layers.28.mlp.down_proj.weight shape: [3072, 8192] +130: model.layers.28.mlp.gate_up_proj.weight shape: [16384, 3072] +131: model.layers.28.post_attention_layernorm.weight shape: [3072] +132: model.layers.28.self_attn.o_proj.weight shape: [3072, 3072] +133: model.layers.28.self_attn.qkv_proj.weight shape: [9216, 3072] +134: model.layers.29.input_layernorm.weight shape: [3072] +135: model.layers.29.mlp.down_proj.weight shape: [3072, 8192] +136: model.layers.29.mlp.gate_up_proj.weight shape: [16384, 3072] +137: model.layers.29.post_attention_layernorm.weight shape: [3072] +138: model.layers.29.self_attn.o_proj.weight shape: [3072, 3072] +139: model.layers.29.self_attn.qkv_proj.weight shape: [9216, 3072] +140: model.layers.3.input_layernorm.weight shape: [3072] +141: model.layers.3.mlp.down_proj.weight shape: [3072, 8192] +142: model.layers.3.mlp.gate_up_proj.weight shape: [16384, 3072] +143: model.layers.3.post_attention_layernorm.weight shape: [3072] +144: model.layers.3.self_attn.o_proj.weight shape: [3072, 3072] +145: model.layers.3.self_attn.qkv_proj.weight shape: [9216, 3072] +146: model.layers.30.input_layernorm.weight shape: [3072] +147: model.layers.30.mlp.down_proj.weight shape: [3072, 8192] +148: model.layers.30.mlp.gate_up_proj.weight shape: [16384, 3072] +149: model.layers.30.post_attention_layernorm.weight shape: [3072] +150: model.layers.30.self_attn.o_proj.weight shape: [3072, 3072] +151: model.layers.30.self_attn.qkv_proj.weight shape: [9216, 3072] +152: model.layers.31.input_layernorm.weight shape: [3072] +153: model.layers.31.mlp.down_proj.weight shape: [3072, 8192] +154: model.layers.31.mlp.gate_up_proj.weight shape: [16384, 3072] +155: model.layers.31.post_attention_layernorm.weight shape: [3072] +156: model.layers.31.self_attn.o_proj.weight shape: [3072, 3072] +157: model.layers.31.self_attn.qkv_proj.weight shape: [9216, 3072] +158: model.layers.4.input_layernorm.weight shape: [3072] +159: model.layers.4.mlp.down_proj.weight shape: [3072, 8192] +160: model.layers.4.mlp.gate_up_proj.weight shape: [16384, 3072] +161: model.layers.4.post_attention_layernorm.weight shape: [3072] +162: model.layers.4.self_attn.o_proj.weight shape: [3072, 3072] +163: model.layers.4.self_attn.qkv_proj.weight shape: [9216, 3072] +164: model.layers.5.input_layernorm.weight shape: [3072] +165: model.layers.5.mlp.down_proj.weight shape: [3072, 8192] +166: model.layers.5.mlp.gate_up_proj.weight shape: [16384, 3072] +167: model.layers.5.post_attention_layernorm.weight shape: [3072] +168: model.layers.5.self_attn.o_proj.weight shape: [3072, 3072] +169: model.layers.5.self_attn.qkv_proj.weight shape: [9216, 3072] +170: model.layers.6.input_layernorm.weight shape: [3072] +171: model.layers.6.mlp.down_proj.weight shape: [3072, 8192] +172: model.layers.6.mlp.gate_up_proj.weight shape: [16384, 3072] +173: model.layers.6.post_attention_layernorm.weight shape: [3072] +174: model.layers.6.self_attn.o_proj.weight shape: [3072, 3072] +175: model.layers.6.self_attn.qkv_proj.weight shape: [9216, 3072] +176: model.layers.7.input_layernorm.weight shape: [3072] +177: model.layers.7.mlp.down_proj.weight shape: [3072, 8192] +178: model.layers.7.mlp.gate_up_proj.weight shape: [16384, 3072] +179: model.layers.7.post_attention_layernorm.weight shape: [3072] +180: model.layers.7.self_attn.o_proj.weight shape: [3072, 3072] +181: model.layers.7.self_attn.qkv_proj.weight shape: [9216, 3072] +182: model.layers.8.input_layernorm.weight shape: [3072] +183: model.layers.8.mlp.down_proj.weight shape: [3072, 8192] +184: model.layers.8.mlp.gate_up_proj.weight shape: [16384, 3072] +185: model.layers.8.post_attention_layernorm.weight shape: [3072] +186: model.layers.8.self_attn.o_proj.weight shape: [3072, 3072] +187: model.layers.8.self_attn.qkv_proj.weight shape: [9216, 3072] +188: model.layers.9.input_layernorm.weight shape: [3072] +189: model.layers.9.mlp.down_proj.weight shape: [3072, 8192] +190: model.layers.9.mlp.gate_up_proj.weight shape: [16384, 3072] +191: model.layers.9.post_attention_layernorm.weight shape: [3072] +192: model.layers.9.self_attn.o_proj.weight shape: [3072, 3072] +193: model.layers.9.self_attn.qkv_proj.weight shape: [9216, 3072] +194: model.norm.weight shape: [3072] diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.TokenizerTest.approved.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.TokenizerTest.approved.txt new file mode 100644 index 0000000000..442bde7bca --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.TokenizerTest.approved.txt @@ -0,0 +1,5 @@ +1, 1815, 366, 3867, 5837, 304, 17545, 18240, 310, 9892, 16397, 322, 8338, 265, 29888, 21211, 29973 +1, 18585, 29991, 2266, 526, 777, 5837, 304, 17545, 9892, 16397, 322, 8338, 265, 29888, 21211, 4208, 29901, 29871, 29896, 29889, 10765, 1648, 322, 8338, 265, 29888, 9216, 10597, 347, 29901, 3164, 355, 9892, 16397, 322, 8338, 265, 29888, 21211, 4208, 411, 777, 27274, 322, 298, 4992, 29889, 29871, 29906, 29889, 10765, 1648, 322, 8338, 265, 29888, 9216, 4497, 328, 29901, 23478, 269, 506, 287, 9892, 16397, 322, 8338, 265, 29888, 21211, 4208, 411, 777, 454, 3712, 3623, 625, 322, 298, 4992, 29889 +1, 1724, 1048, 17069, 385, 29871, 29906, 29916, 718, 29871, 29941, 353, 29871, 29955, 6306, 29973 +1, 29871, 13, 3981, 304, 29871, 29941, 13 +1, 32010, 29871, 13, 3981, 304, 29871, 29941, 32007, 29871, 13, 32001 diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj new file mode 100644 index 0000000000..ed7f2c6342 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj @@ -0,0 +1,33 @@ + + + + net8.0 + enable + enable + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs new file mode 100644 index 0000000000..7d1f251a60 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs @@ -0,0 +1,93 @@ +using ApprovalTests.Namers; +using ApprovalTests.Reporters; +using ApprovalTests; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using static TorchSharp.torch; +using Xunit; +using TorchSharp; +using FluentAssertions; +using Microsoft.ML.TestFramework; +using Xunit.Abstractions; +using Microsoft.ML.Tokenizers; +using Microsoft.ML.GenAI.Core.Extension; +using System.Text.Json; +using Microsoft.ML.GenAI.Phi.Module; +namespace Microsoft.ML.GenAI.Phi.Tests; + +public class Phi2Test : BaseTestClass +{ + public Phi2Test(ITestOutputHelper output) : base(output) + { + } + + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("Approvals")] + public void LoadSafeTensorShapeTest() + { + torch.set_default_device("meta"); + var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\phi-2"; + var configName = "config.json"; + var config = Path.Join(modelWeightFolder, configName); + var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); + var model = new Phi2ForCasualLM(modelConfig); + var stateDictStr = model.PeekShape(); + Approvals.Verify(stateDictStr); + } + + //[Fact] + //[UseReporter(typeof(DiffReporter))] + //[UseApprovalSubdirectory("Approvals")] + //public async Task ForwardTest() + //{ + // // create dummy input id with 128 length and attention mask + // var device = "cuda"; + // var inputIds = torch.arange(128, dtype: ScalarType.Int64, device: device).unsqueeze(0); + // var attentionMask = torch.ones(1, 128, device: device); + // var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\phi-2"; + // var model = Phi2ForCasualLM.FromPretrained(modelWeightFolder, torchDtype: ScalarType.BFloat16, checkPointName: "model.safetensors.index.json", device: "cuda"); + // var input = new CasualLMModelInput(inputIds, attentionMask, past_key_values_length: 0); + // var output = model.forward(input); + // var outputTokenIds = output.last_hidden_state; + // var outputLogits = output.logits; + + // var outputTokenIdsStr = outputTokenIds.Peek("output"); + // var outputLogitsStr = outputLogits.Peek("logits"); + + // var sb = new StringBuilder(); + // sb.AppendLine(outputTokenIdsStr); + // sb.AppendLine(outputLogitsStr); + + // Approvals.Verify(sb.ToString()); + //} + + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("Approvals")] + public void TokenizerTest() + { + var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\phi-2"; + var tokenizer = Tokenizer.CreatePhi2(modelWeightFolder, addBeginOfSentence: true); + tokenizer.EndOfSentenceId.Should().Be(50256); + tokenizer.BeginningOfSentenceId.Should().Be(50256); + var messages = new string[] + { + "Can you provide ways to eat combinations of bananas and dragonfruits?", + "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey.", + "What about solving an 2x + 3 = 7 equation?" + }; + var sb = new StringBuilder(); + foreach (var message in messages) + { + var tokenized = tokenizer.EncodeToIds(message, true, false); + var tokenizedStr = string.Join(", ", tokenized.Select(x => x.ToString())); + + sb.AppendLine(tokenizedStr); + } + Approvals.Verify(sb.ToString()); + } +} diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs new file mode 100644 index 0000000000..3965a3d168 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs @@ -0,0 +1,127 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Text; +using System.Text.Json; +using ApprovalTests; +using ApprovalTests.Namers; +using ApprovalTests.Reporters; +using FluentAssertions; +using Microsoft.ML.GenAI.Core.Extension; +using Microsoft.ML.TestFramework; +using TorchSharp; +using Xunit; +using Xunit.Abstractions; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Phi.Tests; + +public class Phi3Tests : BaseTestClass +{ + public Phi3Tests(ITestOutputHelper output) : base(output) + { + torch.set_default_device("meta"); + } + + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("Approvals")] + public void Phi3Mini4KShapeTest() + { + var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\Phi-3-mini-4k-instruct"; + var config = Path.Join(modelWeightFolder, "config.json"); + var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); + var model = new Phi3ForCasualLM(modelConfig); + var stateDictStr = model.PeekShape(); + Approvals.Verify(stateDictStr); + } + + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("Approvals")] + public void Phi3Medium4KShapeTest() + { + var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\Phi-3-medium-4k-instruct"; + var config = Path.Join(modelWeightFolder, "config.json"); + var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); + var model = new Phi3ForCasualLM(modelConfig); + var stateDictStr = model.PeekShape(); + Approvals.Verify(stateDictStr); + } + + + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("Approvals")] + public void Phi3Medium128KShapeTest() + { + var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\Phi-3-medium-128k-instruct"; + var config = Path.Join(modelWeightFolder, "config.json"); + var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); + var model = new Phi3ForCasualLM(modelConfig); + var stateDictStr = model.PeekShape(); + Approvals.Verify(stateDictStr); + } + + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("Approvals")] + public void Phi3Mini128KShapeTest() + { + var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\Phi-3-mini-128k-instruct"; + var config = Path.Join(modelWeightFolder, "config.json"); + var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); + var model = new Phi3ForCasualLM(modelConfig); + var stateDictStr = model.PeekShape(); + Approvals.Verify(stateDictStr); + } + + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("Approvals")] + public void Phi3Mini128KLayerSizeTest() + { + var dtype = ScalarType.BFloat16; + var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\Phi-3-mini-128k-instruct"; + var config = Path.Join(modelWeightFolder, "config.json"); + var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); + modelConfig.DType = dtype; + var model = new Phi3ForCasualLM(modelConfig); + + var size = model.GetSizeForEachDynamicLayerInBytes(); + // convert size to MB + var sizeInMB = size.ToDictionary(x => x.Key, x => x.Value * 1.0f / 1024 / 1024); + + var json = JsonSerializer.Serialize(sizeInMB, new JsonSerializerOptions { WriteIndented = true }); + Approvals.Verify(json); + } + + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("Approvals")] + public void TokenizerTest() + { + var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\Phi-3-mini-4k-instruct"; + var tokenizer = LLama2Tokenizer.FromPretrained(modelWeightFolder); + tokenizer.BosId.Should().Be(1); + tokenizer.EosId.Should().Be(2); + var messages = new string[] + { + "Can you provide ways to eat combinations of bananas and dragonfruits?", + "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey.", + "What about solving an 2x + 3 = 7 equation?", + "\nCount to 3\n", + "<|user|>\nCount to 3<|end|>\n<|assistant|>", + }; + var sb = new StringBuilder(); + foreach (var message in messages) + { + var tokenized = tokenizer.Encode(message, true, false); + var tokenizedStr = string.Join(", ", tokenized.Select(x => x.ToString())); + + sb.AppendLine(tokenizedStr); + } + Approvals.Verify(sb.ToString()); + } +} From 1493ebae3ca5d7b91a3d4ac12145723354f4aec0 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 26 Jun 2024 17:30:38 -0700 Subject: [PATCH 02/41] formatter --- .../Extension/CausalLMPipelineExtension.cs | 2 +- .../Microsoft.ML.GenAI.Core.csproj | 2 +- .../Module/GenAILinear.cs | 6 ++++- .../Module/NewGELUActivation.cs | 6 ++++- .../Module/Phi2Attention.cs | 6 ++++- .../Module/Phi2DecoderLayer.cs | 6 ++++- src/Microsoft.ML.GenAI.Phi/Module/Phi2MLP.cs | 6 ++++- .../Module/Phi2Model.cs | 6 ++++- .../Module/Phi2RotaryEmbedding.cs | 6 ++++- .../Module/Phi3Attention.cs | 6 ++++- .../Module/Phi3DecoderLayer.cs | 6 ++++- src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs | 10 +++++--- .../Module/Phi3RMSNorm.cs | 8 +++++-- .../Module/Phi3RotaryEmbedding.cs | 6 ++++- .../Module/Phi3SuScaledRotaryEmbedding.cs | 6 ++++- src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs | 6 ++++- .../Phi2/Phi2ForCasualLM.cs | 6 ++++- src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs | 6 ++++- .../Phi3/Phi3ForCasualLM.cs | 6 ++++- .../Phi3/Phi3Tokenzier.cs | 6 ++++- src/Microsoft.ML.GenAI.Phi/Utils.cs | 8 +++++-- test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs | 24 +++++++++++-------- 22 files changed, 115 insertions(+), 35 deletions(-) diff --git a/src/Microsoft.ML.GenAI.Core/Extension/CausalLMPipelineExtension.cs b/src/Microsoft.ML.GenAI.Core/Extension/CausalLMPipelineExtension.cs index 3a1041ee8a..b9d6bb3f1c 100644 --- a/src/Microsoft.ML.GenAI.Core/Extension/CausalLMPipelineExtension.cs +++ b/src/Microsoft.ML.GenAI.Core/Extension/CausalLMPipelineExtension.cs @@ -7,8 +7,8 @@ using System.Linq; using System.Text; using System.Threading.Tasks; -using static TorchSharp.torch; using TorchSharp; +using static TorchSharp.torch; namespace Microsoft.ML.GenAI.Core.Extension; diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj index 9387bfbabe..ad4520e133 100644 --- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj +++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj @@ -1,7 +1,7 @@  - net8.0 + net8.0 false enable preview diff --git a/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs b/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs index c5319ffddf..d206c0dfa2 100644 --- a/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs +++ b/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs @@ -1,4 +1,8 @@ -using TorchSharp; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using TorchSharp; using static TorchSharp.torch; namespace Microsoft.ML.GenAI; diff --git a/src/Microsoft.ML.GenAI.Core/Module/NewGELUActivation.cs b/src/Microsoft.ML.GenAI.Core/Module/NewGELUActivation.cs index a20ad47bd6..4c46e53104 100644 --- a/src/Microsoft.ML.GenAI.Core/Module/NewGELUActivation.cs +++ b/src/Microsoft.ML.GenAI.Core/Module/NewGELUActivation.cs @@ -1,4 +1,8 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using TorchSharp; using static TorchSharp.torch; diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi2Attention.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi2Attention.cs index 7bb45bef3f..918ae7c99b 100644 --- a/src/Microsoft.ML.GenAI.Phi/Module/Phi2Attention.cs +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi2Attention.cs @@ -1,4 +1,8 @@ -using System.Diagnostics.Contracts; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics.Contracts; using TorchSharp; using TorchSharp.Modules; using static TorchSharp.torch; diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi2DecoderLayer.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi2DecoderLayer.cs index f1f87ee079..7931e32b79 100644 --- a/src/Microsoft.ML.GenAI.Phi/Module/Phi2DecoderLayer.cs +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi2DecoderLayer.cs @@ -1,4 +1,8 @@ -using TorchSharp; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using TorchSharp; using TorchSharp.Modules; using static TorchSharp.torch; diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi2MLP.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi2MLP.cs index 8d16bbb152..384d012e22 100644 --- a/src/Microsoft.ML.GenAI.Phi/Module/Phi2MLP.cs +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi2MLP.cs @@ -1,4 +1,8 @@ -using TorchSharp; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using TorchSharp; using TorchSharp.Modules; using static TorchSharp.torch; diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi2Model.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi2Model.cs index 2d7ee2d997..b96e0409f9 100644 --- a/src/Microsoft.ML.GenAI.Phi/Module/Phi2Model.cs +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi2Model.cs @@ -1,4 +1,8 @@ -using System.Diagnostics.Contracts; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics.Contracts; using TorchSharp; using TorchSharp.Modules; using static TorchSharp.torch; diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi2RotaryEmbedding.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi2RotaryEmbedding.cs index ab14c9bb6a..a21ed4959e 100644 --- a/src/Microsoft.ML.GenAI.Phi/Module/Phi2RotaryEmbedding.cs +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi2RotaryEmbedding.cs @@ -1,4 +1,8 @@ -using TorchSharp; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using TorchSharp; using static TorchSharp.torch; namespace Microsoft.ML.GenAI.Phi.Module; diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3Attention.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3Attention.cs index d1ffd970d5..c51b0eef0b 100644 --- a/src/Microsoft.ML.GenAI.Phi/Module/Phi3Attention.cs +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3Attention.cs @@ -1,4 +1,8 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using System.Collections.Generic; using System.Diagnostics.Contracts; using System.Linq; diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3DecoderLayer.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3DecoderLayer.cs index 55a4700db9..399cd25646 100644 --- a/src/Microsoft.ML.GenAI.Phi/Module/Phi3DecoderLayer.cs +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3DecoderLayer.cs @@ -1,9 +1,13 @@ -using Microsoft.ML.GenAI.Core; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; +using Microsoft.ML.GenAI.Core; using TorchSharp.Modules; using static TorchSharp.torch; diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs index abec0d78cf..752ea9dd2b 100644 --- a/src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs @@ -1,11 +1,15 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; -using static TorchSharp.torch; -using TorchSharp.Modules; using TorchSharp; +using TorchSharp.Modules; +using static TorchSharp.torch; namespace Microsoft.ML.GenAI.Phi.Module; #pragma warning disable MSML_GeneralName // This name should be PascalCased diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3RMSNorm.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3RMSNorm.cs index 23cfab24ba..e8c847268e 100644 --- a/src/Microsoft.ML.GenAI.Phi/Module/Phi3RMSNorm.cs +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3RMSNorm.cs @@ -1,11 +1,15 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; -using static TorchSharp.torch; using TorchSharp; using TorchSharp.Modules; +using static TorchSharp.torch; namespace Microsoft.ML.GenAI.Phi.Module; diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3RotaryEmbedding.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3RotaryEmbedding.cs index 226d9b8d14..9b04a301d6 100644 --- a/src/Microsoft.ML.GenAI.Phi/Module/Phi3RotaryEmbedding.cs +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3RotaryEmbedding.cs @@ -1,4 +1,8 @@ -using TorchSharp; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using TorchSharp; using static TorchSharp.torch; namespace Microsoft.ML.GenAI.Phi.Module; diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3SuScaledRotaryEmbedding.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3SuScaledRotaryEmbedding.cs index 7084c15839..ce0e70b686 100644 --- a/src/Microsoft.ML.GenAI.Phi/Module/Phi3SuScaledRotaryEmbedding.cs +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3SuScaledRotaryEmbedding.cs @@ -1,4 +1,8 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using System.Collections.Generic; using System.Linq; using System.Text; diff --git a/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs b/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs index 2727321f6b..cafafba1a6 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs @@ -1,4 +1,8 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using System.Collections.Generic; using System.Linq; using System.Text; diff --git a/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2ForCasualLM.cs b/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2ForCasualLM.cs index af65f00ff2..efb3f23de9 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2ForCasualLM.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2ForCasualLM.cs @@ -1,4 +1,8 @@ -using System.CodeDom; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.CodeDom; using System.Text.Json; using System.Text.Json.Serialization; using Microsoft.ML.GenAI.Core; diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs index 2e5f755d95..023819b0bf 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs @@ -1,4 +1,8 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using System.Collections.Generic; using System.Linq; using System.Text; diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs index 9992c92c30..8ab7ecc652 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs @@ -1,4 +1,8 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using System.Collections.Generic; using System.Linq; using System.Text; diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs index 8ef4f6fbde..387cefc63c 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs @@ -1,4 +1,7 @@ -using Microsoft.ML.Tokenizers; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using System; using System.Collections.Generic; using System.Linq; @@ -7,6 +10,7 @@ using System.Text.Json; using System.Text.RegularExpressions; using System.Threading.Tasks; +using Microsoft.ML.Tokenizers; namespace Microsoft.ML.GenAI.Phi; public interface ITokenizer diff --git a/src/Microsoft.ML.GenAI.Phi/Utils.cs b/src/Microsoft.ML.GenAI.Phi/Utils.cs index c4a05bbe40..c6f195d4b0 100644 --- a/src/Microsoft.ML.GenAI.Phi/Utils.cs +++ b/src/Microsoft.ML.GenAI.Phi/Utils.cs @@ -1,10 +1,14 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; -using TorchSharp.Modules; using TorchSharp; +using TorchSharp.Modules; using static TorchSharp.torch; using static TorchSharp.torch.nn; namespace Microsoft.ML.GenAI.Phi; diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs index 7d1f251a60..2666bd3175 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs @@ -1,21 +1,25 @@ -using ApprovalTests.Namers; -using ApprovalTests.Reporters; -using ApprovalTests; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using System; using System.Collections.Generic; using System.Linq; using System.Text; +using System.Text.Json; using System.Threading.Tasks; -using static TorchSharp.torch; -using Xunit; -using TorchSharp; +using ApprovalTests; +using ApprovalTests.Namers; +using ApprovalTests.Reporters; using FluentAssertions; -using Microsoft.ML.TestFramework; -using Xunit.Abstractions; -using Microsoft.ML.Tokenizers; using Microsoft.ML.GenAI.Core.Extension; -using System.Text.Json; using Microsoft.ML.GenAI.Phi.Module; +using Microsoft.ML.TestFramework; +using Microsoft.ML.Tokenizers; +using TorchSharp; +using Xunit; +using Xunit.Abstractions; +using static TorchSharp.torch; namespace Microsoft.ML.GenAI.Phi.Tests; public class Phi2Test : BaseTestClass From 972c7c9207e98b27dc3bcd1ec2fd5f43c07fac38 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Thu, 27 Jun 2024 12:06:32 -0700 Subject: [PATCH 03/41] refactor Phi3Tokenizer --- .../Phi3/Phi3Tokenzier.cs | 244 +++++++----------- .../Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs | 4 +- 2 files changed, 97 insertions(+), 151 deletions(-) diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs index 387cefc63c..8f29f45750 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs @@ -2,35 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections.Generic; -using System.Linq; -using System.Reflection; using System.Text; -using System.Text.Json; using System.Text.RegularExpressions; -using System.Threading.Tasks; using Microsoft.ML.Tokenizers; -namespace Microsoft.ML.GenAI.Phi; -public interface ITokenizer -{ - public int BosId { get; } - - public int EosId { get; } - - public string Decode(int[] input); - - public int[] Encode(string input, bool bos, bool eos); -} - -/// -/// Copied from https://github.com/LittleLittleCloud/Torchsharp-llama/blob/main/ITokenizer.cs -/// -public class LLama2Tokenizer : ITokenizer +public class Phi3Tokenizer : Tokenizer { private readonly SentencePieceBpe _tokenizer; private readonly bool _addPrecedingSpace; + private readonly bool _addBeginningOfSentence; + private readonly bool _addEndOfSentence; private const string SystemSymbol = "<|system|>"; private const string UserSymbol = "<|user|>"; private const string AssistantSymbol = "<|assistant|>"; @@ -47,134 +28,25 @@ public class LLama2Tokenizer : ITokenizer { EndSymbol, EndSymbolId } }; - public LLama2Tokenizer(string modelPath, bool addPrecedingSpace = true) + public Phi3Tokenizer(string modelPath, + bool addPrecedingSpace = true, + bool addBeginningOfSentence = true, + bool addEndOfSentence = true) { var modelStream = File.OpenRead(modelPath); this._addPrecedingSpace = addPrecedingSpace; + this._addBeginningOfSentence = addBeginningOfSentence; + this._addEndOfSentence = addEndOfSentence; this._tokenizer = (SentencePieceBpe)Tokenizer.CreateLlama(modelStream, false, false); - - // use reflection to set the readonly ByteFallback property to false - //var backingField = typeof(SentencePieceBpe).GetField("k__BackingField", BindingFlags.NonPublic | BindingFlags.Instance); - //backingField.SetValue(this.tokenizer, false); } - //public LLama2Tokenizer(string vocabPath, string mergesPath, bool addPrecedingSpace = true, int padToken = -1, int startToken = 1, int endToken = 2) - //{ - // this.BosId = startToken; - // this.EosId = endToken; - // this.addPrecedingSpace = addPrecedingSpace; - // this.PadId = padToken; - // var bpe = new Bpe(vocabPath, mergesPath); - // this.tokenizer = new Tokenizer(bpe, preTokenizer: new PreTokenizer(), normalizer: new Norm()); - // var decoder = new TokenizeDecoder(this.tokenizer.Model.IdToToken(this.BosId)!, this.tokenizer.Model.IdToToken(this.EosId)!); - // this.tokenizer.Decoder = decoder; - //} - - //public LLama2Tokenizer(Dictionary vocab, List merges, bool addPrecedingSpace = true, int padToken = -1, int startToken = 1, int endToken = 2) - //{ - // this.BosId = startToken; - // this.EosId = endToken; - // this.addPrecedingSpace = addPrecedingSpace; - // this.PadId = padToken; - // // save vocab to vocab-temp.json - // var vocabTempPath = "vocab-temp.json"; - // var json = JsonSerializer.Serialize(vocab); - // File.WriteAllText(vocabTempPath, json); - - // // save merges to merges-temp.txt - // var mergesTempPath = "merges-temp.txt"; - // // filter out merges that contain newline character because it will cause error in BPE - // merges = merges.Where(x => !x.Contains('\r')).ToList(); - // File.WriteAllLines(mergesTempPath, merges); - - // var bpe = new Bpe(vocabTempPath, mergesTempPath); - // this.tokenizer = new Tokenizer(bpe, preTokenizer: new PreTokenizer(), normalizer: new Norm()); - // var decoder = new TokenizeDecoder(this.tokenizer.Model.IdToToken(this.BosId)!, this.tokenizer.Model.IdToToken(this.EosId)!); - // this.tokenizer.Decoder = decoder; - - // // delete temp files - // File.Delete(vocabTempPath); - // File.Delete(mergesTempPath); - //} - - public static LLama2Tokenizer FromPretrained( + + public static Phi3Tokenizer FromPretrained( string folder, string modelName = "tokenizer.model") { - return new LLama2Tokenizer(Path.Combine(folder, modelName)); + return new Phi3Tokenizer(Path.Combine(folder, modelName)); } - //public static LLama2Tokenizer FromPretrained( - // string folder, - // string tokenizerJsonPath = "tokenizer.json", - // string specialTokensMapPath = "special_tokens_map.json" - //) - //{ - // tokenizerJsonPath = Path.Combine(folder, tokenizerJsonPath); - // var json = File.ReadAllText(tokenizerJsonPath); - // var jsonDocument = JsonDocument.Parse(json); - // // vocab: .model.vocab - // var vocabNode = jsonDocument.RootElement.GetProperty("model").GetProperty("vocab"); - - // // to Dictionary - // var vocab = new Dictionary(); - // foreach (var item in vocabNode.EnumerateObject()) - // { - // vocab[item.Name] = item.Value.GetInt32(); - // } - - // // added tokens: .added_tokens - // var addedTokensNode = jsonDocument.RootElement.GetProperty("added_tokens"); - // foreach (var item in addedTokensNode.EnumerateArray()) - // { - // // get id from item.id - // var id = item.GetProperty("id").GetInt32(); - // var content = item.GetProperty("content").GetString()!; - // vocab[content] = id; - // } - - // // merges: .model.merges - // var mergesNode = jsonDocument.RootElement.GetProperty("model").GetProperty("merges"); - // // merges: List - // var merges = new List(); - // foreach (var item in mergesNode.EnumerateArray()) - // { - // merges.Add(item.GetString()!); - // } - - // int startToken = 1, endToken = 2, padToken = -1; - // var specialTokenJsonPath = Path.Combine(folder, specialTokensMapPath); - // if (File.Exists(specialTokenJsonPath)) - // { - // var specialTokenJson = File.ReadAllText(specialTokenJsonPath); - // var specialTokenMapDocument = JsonDocument.Parse(specialTokenJson); - - // // retrieve bos_token, eos_token, pad_token if exists - // if (specialTokenMapDocument.RootElement.TryGetProperty("bos_token", out var bosTokenNode)) - // { - // var bos_token_content = bosTokenNode.GetProperty("content").GetString()!; - // startToken = vocab[bos_token_content]; - // } - - // if (specialTokenMapDocument.RootElement.TryGetProperty("eos_token", out var eosTokenNode)) - // { - // var eos_token_content = eosTokenNode.GetProperty("content").GetString()!; - // endToken = vocab[eos_token_content]; - // } - - // if (specialTokenMapDocument.RootElement.TryGetProperty("pad_token", out var padTokenNode)) - // { - // var pad_token_content = padTokenNode.GetProperty("content").GetString()!; - // padToken = vocab[pad_token_content]; - // } - // } - - // return new LLama2Tokenizer(vocab, merges, padToken: padToken, addPrecedingSpace: false, startToken: startToken, endToken: endToken); - //} - - //public int VocabSize => this.tokenizer..GetVocabSize(); - - public int PadId { get => this._tokenizer.UnknownId; } - public int BosId { get => this._tokenizer.BeginningOfSentenceId; } public int EosId { get => this._tokenizer.EndOfSentenceId; } @@ -190,8 +62,41 @@ public string Decode(int[] input) return str; } - public int[] Encode(string input, bool bos, bool eos) + public override IReadOnlyList Encode(ReadOnlySpan text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) { + var tokens = new List(); + var normalizedText = new StringBuilder(); + var input = text.ToString(); + + // step 1: + // replace all special tokens to + var re = new Regex($"{SystemSymbol.Replace("|", "\\|")}|{UserSymbol.Replace("|", "\\|")}|{AssistantSymbol.Replace("|", "\\|")}|{EndSymbol.Replace("|", "\\|")}"); + var matches = re.Matches(input); + var matchesList = new List(); + foreach (Match match in matches) + { + // replace the first special tokens with + var specialToken = match.Value; + var index = input.IndexOf(specialToken); + var subString = input.Substring(0, index); + var subTokens = this._tokenizer.Encode(subString, out var subNormalizeString, addBeginningOfSentence: false, addEndOfSentence: false, considerPreTokenization: considerPreTokenization, considerNormalization: considerNormalization).ToArray(); + normalizedText.Append(subNormalizeString); + tokens.AddRange(subTokens); + tokens.Add(new Token(this._specialTokenMap[specialToken], specialToken, (index, specialToken.Length))); + input = input.Remove(0, index + specialToken.Length); + } + + tokens.AddRange(this._tokenizer.Encode(input, out var normailzeString, addBeginningOfSentence: false, addEndOfSentence: false, considerPreTokenization: considerPreTokenization, considerNormalization: considerNormalization).ToArray()); + + normalizedText.Append(normailzeString); + normalizedString = normalizedText.ToString(); + + return tokens.ToArray(); + } + + public override IReadOnlyList EncodeToIds(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true) + { + var input = text.ToString(); // step 1: // replace all special tokens to var re = new Regex($"{SystemSymbol.Replace("|", "\\|")}|{UserSymbol.Replace("|", "\\|")}|{AssistantSymbol.Replace("|", "\\|")}|{EndSymbol.Replace("|", "\\|")}"); @@ -204,23 +109,64 @@ public int[] Encode(string input, bool bos, bool eos) var specialToken = match.Value; var index = input.IndexOf(specialToken); var subString = input.Substring(0, index); - var subTokens = this._tokenizer.EncodeToIds(subString, addBeginningOfSentence: false, addEndOfSentence: false).ToArray(); + var subTokens = this._tokenizer.EncodeToIds(subString, addBeginningOfSentence: false, addEndOfSentence: false, considerPreTokenization: considerPreTokenization, considerNormalization: considerNormalization).ToArray(); tokens.AddRange(subTokens); tokens.Add(this._specialTokenMap[specialToken]); input = input.Remove(0, index + specialToken.Length); } - tokens.AddRange(this._tokenizer.EncodeToIds(input, addBeginningOfSentence: false, addEndOfSentence: false).ToArray()); - if (bos) + tokens.AddRange(this._tokenizer.EncodeToIds(input, addBeginningOfSentence: false, addEndOfSentence: false, considerPreTokenization: considerPreTokenization, considerNormalization: considerNormalization).ToArray()); + + return this._addBeginningOfSentence ? new int[] { this.BosId }.Concat(tokens).ToArray() : tokens.ToArray(); + } + + public override IReadOnlyList EncodeToIds(ReadOnlySpan text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) + { + var tokens = this.Encode(text, out normalizedText, considerPreTokenization, considerNormalization); + + var tokenIds = tokens.Select(x => x.Id).ToArray(); + + textLength = normalizedText?.Length ?? 0; + + return tokenIds.Length > maxTokenCount ? tokenIds.Take(maxTokenCount).ToArray() : tokenIds; + } + + public override int CountTokens(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true) + { + var tokens = this.EncodeToIds(text, considerPreTokenization, considerNormalization); + + return tokens.Count; + } + + public override int IndexOfTokenCount(ReadOnlySpan text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + { + return _tokenizer.IndexOfTokenCount(text, maxTokenCount, out normalizedString, out tokenCount, considerPreTokenization, considerNormalization); + } + + public override int LastIndexOfTokenCount(ReadOnlySpan text, int maxTokenCount, out string? processedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + { + return _tokenizer.LastIndexOfTokenCount(text, maxTokenCount, out processedText, out tokenCount, considerPreTokenization, considerNormalization); + } + + public override int? MapTokenToId(ReadOnlySpan token) + { + // check if token in special tokens + var tokenStr = token.ToString(); + if (_specialTokenMap.ContainsKey(tokenStr)) { - tokens.Insert(0, this.BosId); + return _specialTokenMap[tokenStr]; } - if (eos) + + return _tokenizer.MapTokenToId(token); + } + + public override string? MapIdToToken(int id) + { + if (_specialTokenMap.ContainsValue(id)) { - tokens.Add(this.EosId); + return _specialTokenMap.First(x => x.Value == id).Key; } - - return tokens.ToArray(); + return _tokenizer.MapIdToToken(id); } } diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs index 3965a3d168..fd204b2449 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs @@ -103,7 +103,7 @@ public void Phi3Mini128KLayerSizeTest() public void TokenizerTest() { var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\Phi-3-mini-4k-instruct"; - var tokenizer = LLama2Tokenizer.FromPretrained(modelWeightFolder); + var tokenizer = Phi3Tokenizer.FromPretrained(modelWeightFolder); tokenizer.BosId.Should().Be(1); tokenizer.EosId.Should().Be(2); var messages = new string[] @@ -117,7 +117,7 @@ public void TokenizerTest() var sb = new StringBuilder(); foreach (var message in messages) { - var tokenized = tokenizer.Encode(message, true, false); + var tokenized = tokenizer.EncodeToIds(message, considerPreTokenization: true); var tokenizedStr = string.Join(", ", tokenized.Select(x => x.ToString())); sb.AppendLine(tokenizedStr); From 349a4262ec76efeab892264bc7529179193bde50 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Thu, 27 Jun 2024 13:48:09 -0700 Subject: [PATCH 04/41] update --- .../Phi3/Phi3Tokenzier.cs | 77 +++++++++++++++---- .../Microsoft.ML.Tokenizers.csproj | 4 + .../Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs | 2 + .../Microsoft.ML.Tokenizers.Tests.csproj | 2 + 4 files changed, 70 insertions(+), 15 deletions(-) diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs index 8f29f45750..dada1fa64a 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs @@ -51,17 +51,6 @@ public static Phi3Tokenizer FromPretrained( public int EosId { get => this._tokenizer.EndOfSentenceId; } - public string Decode(int[] input) - { - var str = this._tokenizer.Decode(input) ?? throw new Exception("Failed to decode"); - if (this._addPrecedingSpace) - { - str = str.TrimStart(); - } - - return str; - } - public override IReadOnlyList Encode(ReadOnlySpan text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) { var tokens = new List(); @@ -97,6 +86,10 @@ public override IReadOnlyList Encode(ReadOnlySpan text, out string? public override IReadOnlyList EncodeToIds(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true) { var input = text.ToString(); + if (this._addPrecedingSpace) + { + //input = " " + input; + } // step 1: // replace all special tokens to var re = new Regex($"{SystemSymbol.Replace("|", "\\|")}|{UserSymbol.Replace("|", "\\|")}|{AssistantSymbol.Replace("|", "\\|")}|{EndSymbol.Replace("|", "\\|")}"); @@ -105,17 +98,17 @@ public override IReadOnlyList EncodeToIds(ReadOnlySpan text, bool con var tokens = new List(); foreach (Match match in matches) { - // replace the first special tokens with var specialToken = match.Value; var index = input.IndexOf(specialToken); var subString = input.Substring(0, index); - var subTokens = this._tokenizer.EncodeToIds(subString, addBeginningOfSentence: false, addEndOfSentence: false, considerPreTokenization: considerPreTokenization, considerNormalization: considerNormalization).ToArray(); - tokens.AddRange(subTokens); + var subTokens = this._tokenizer.EncodeToIds(subString, addBeginningOfSentence: false, addEndOfSentence: false, considerPreTokenization: false, considerNormalization: true).ToArray(); + // remove the first sub Token as it will always be '_' + tokens.AddRange(subTokens.Skip(1)); tokens.Add(this._specialTokenMap[specialToken]); input = input.Remove(0, index + specialToken.Length); } - tokens.AddRange(this._tokenizer.EncodeToIds(input, addBeginningOfSentence: false, addEndOfSentence: false, considerPreTokenization: considerPreTokenization, considerNormalization: considerNormalization).ToArray()); + tokens.AddRange(this._tokenizer.EncodeToIds(input, addBeginningOfSentence: false, addEndOfSentence: false, considerPreTokenization: false, considerNormalization: true).ToArray()); return this._addBeginningOfSentence ? new int[] { this.BosId }.Concat(tokens).ToArray() : tokens.ToArray(); } @@ -148,6 +141,60 @@ public override int LastIndexOfTokenCount(ReadOnlySpan text, int maxTokenC return _tokenizer.LastIndexOfTokenCount(text, maxTokenCount, out processedText, out tokenCount, considerPreTokenization, considerNormalization); } + public override string? Decode(IEnumerable ids) + { + // step 1 + // replace all special token ids to ukn ids + var replacedIds = ids.SelectMany(id => + { + if (this._specialTokenMap.ContainsValue(id)) + { + var key = this._specialTokenMap.First(x => x.Value == id).Key; + var ids = this._tokenizer.EncodeToIds(key, false, false, false, false); + var recoverKey = this._tokenizer.Decode(ids) ?? throw new Exception("Failed to decode ids"); + return ids; + } + else + { + return new List { id }; + } + }); + + var str = this._tokenizer.Decode(replacedIds) ?? throw new Exception("Failed to decode ids"); + + return str; + + //var tokens = new List(); + //foreach (var id in ids) + //{ + // if (_specialTokenMap.ContainsValue(id)) + // { + // tokens.Add(_specialTokenMap.First(x => x.Value == id).Key); + // } + // else + // { + // tokens.Add(this._tokenizer.MapIdToToken(id) ?? throw new Exception("Failed to map id to token")); + // } + //} + + //if (this._addBeginningOfSentence) + //{ + // tokens = tokens[1..].ToList(); + //} + + //var str = string.Join("", tokens); + + //// replace Dummy with whitespace + //str = str.Replace(SentencePieceNormalizer.DummyPrefix, ' '); + + //if (this._addPrecedingSpace) + //{ + // str = str.TrimStart(' '); + //} + + //return str; + } + public override int? MapTokenToId(ReadOnlySpan token) { // check if token in special tokens diff --git a/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj b/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj index fbff32071e..a61041f8e1 100644 --- a/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj +++ b/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj @@ -97,6 +97,10 @@ + + + + x.ToString())); sb.AppendLine(tokenizedStr); diff --git a/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj b/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj index 802cae464a..7fb56e82aa 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj +++ b/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj @@ -48,4 +48,6 @@ + + \ No newline at end of file From 17b689af77151463cf3b24100c32d558f779c620 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Thu, 27 Jun 2024 14:36:49 -0700 Subject: [PATCH 05/41] add configuration for phi-series --- .../Microsoft.ML.GenAI.Phi.csproj | 4 + src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs | 13 ++ src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs | 34 ++++ .../Resource/Config/phi-2-config.json | 31 ++++ .../phi-3-medium-128k-instruct-config.json | 170 ++++++++++++++++++ .../phi-3-medium-4k-instruct-config.json | 36 ++++ .../phi-3-mini-128k-instruct-config.json | 140 +++++++++++++++ .../Config/phi-3-mini-4k-instruct-config.json | 36 ++++ src/Microsoft.ML.GenAI.Phi/Resource/README.md | 8 + src/Microsoft.ML.GenAI.Phi/Utils.cs | 18 +- ...sts.Phi3Mini128KLayerSizeTest.approved.txt | 64 +++---- test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs | 34 +--- .../Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs | 30 +--- 13 files changed, 529 insertions(+), 89 deletions(-) create mode 100644 src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-2-config.json create mode 100644 src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-3-medium-128k-instruct-config.json create mode 100644 src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-3-medium-4k-instruct-config.json create mode 100644 src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-3-mini-128k-instruct-config.json create mode 100644 src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-3-mini-4k-instruct-config.json create mode 100644 src/Microsoft.ML.GenAI.Phi/Resource/README.md diff --git a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj index e313da05ef..59f77e73c8 100644 --- a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj +++ b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj @@ -19,4 +19,8 @@ + + + + diff --git a/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs b/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs index cafafba1a6..fdba74ba77 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using System.Text.Json; using System.Text.Json.Serialization; using System.Threading.Tasks; using static TorchSharp.torch; @@ -38,6 +39,18 @@ public Phi2Config() this.Dtype = ScalarType.Float32; } + static Phi2Config() + { + var phi2ConfigContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-2-config.json"); + var phi2Config = JsonSerializer.Deserialize(phi2ConfigContent) ?? throw new ArgumentNullException(nameof(phi2ConfigContent)); + Phi2 = phi2Config; + } + + /// + /// The default phi-2 configuration created from https://huggingface.co/microsoft/phi-2/blob/main/config.json. + /// + public static Phi2Config Phi2 { get; } + [JsonPropertyName("vocab_size")] public int VocabSize { get; set; } diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs index 023819b0bf..def5ab3448 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using System.Text.Json; using System.Text.Json.Serialization; using System.Threading.Tasks; using static TorchSharp.torch; @@ -39,6 +40,39 @@ public Phi3Config() this.AttnImplementation = "eager"; } + static Phi3Config() + { + var phi3Mini4kInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-mini-4k-instruct-config.json"); + var phi3Mini128kInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-mini-128k-instruct-config.json"); + var phi3Medium4kInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-medium-4k-instruct-config.json"); + var phi3Medium128kInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-medium-128k-instruct-config.json"); + + Phi3Mini4kInstruct = JsonSerializer.Deserialize(phi3Mini4kInstructContent) ?? throw new ArgumentNullException(nameof(phi3Mini4kInstructContent)); + Phi3Mini128kInstruct = JsonSerializer.Deserialize(phi3Mini128kInstructContent) ?? throw new ArgumentNullException(nameof(phi3Mini128kInstructContent)); + Phi3Medium4kInstruct = JsonSerializer.Deserialize(phi3Medium4kInstructContent) ?? throw new ArgumentNullException(nameof(phi3Medium4kInstructContent)); + Phi3Medium128kInstruct = JsonSerializer.Deserialize(phi3Medium128kInstructContent) ?? throw new ArgumentNullException(nameof(phi3Medium128kInstructContent)); + } + + /// + /// The phi-3-mini-4k-instruct configuration created from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json. + /// + public static Phi3Config Phi3Mini4kInstruct { get; } + + /// + /// The phi-3-medium-4k-instruct configuration created from https://huggingface.co/microsoft/Phi-3-medium-4k-instruct/blob/main/config.json. + /// + public static Phi3Config Phi3Medium4kInstruct { get; } + + /// + /// The phi-3-medium-128k-instruct configuration created from https://huggingface.co/microsoft/Phi-3-medium-128k-instruct/blob/main/config.json. + /// + public static Phi3Config Phi3Medium128kInstruct { get; } + + /// + /// The phi-3-mini-128k-instruct configuration created from https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/blob/main/config.json. + /// + public static Phi3Config Phi3Mini128kInstruct { get; } + [JsonPropertyName("vocab_size")] public int VocabSize { get; set; } diff --git a/src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-2-config.json b/src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-2-config.json new file mode 100644 index 0000000000..c3a5b1ce17 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-2-config.json @@ -0,0 +1,31 @@ +{ + "_name_or_path": "microsoft/phi-2", + "architectures": [ + "PhiForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 50256, + "embd_pdrop": 0.0, + "eos_token_id": 50256, + "hidden_act": "gelu_new", + "hidden_size": 2560, + "initializer_range": 0.02, + "intermediate_size": 10240, + "layer_norm_eps": 1e-05, + "max_position_embeddings": 2048, + "model_type": "phi", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "partial_rotary_factor": 0.4, + "qk_layernorm": false, + "resid_pdrop": 0.1, + "rope_scaling": null, + "rope_theta": 10000.0, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.37.0", + "use_cache": true, + "vocab_size": 51200 + } + \ No newline at end of file diff --git a/src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-3-medium-128k-instruct-config.json b/src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-3-medium-128k-instruct-config.json new file mode 100644 index 0000000000..f058b81505 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-3-medium-128k-instruct-config.json @@ -0,0 +1,170 @@ +{ + "_name_or_path": "Phi-3-medium-128k-instruct", + "architectures": [ + "Phi3ForCausalLM" + ], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_phi3.Phi3Config", + "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM" + }, + "bos_token_id": 1, + "embd_pdrop": 0.0, + "eos_token_id": 32000, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 17920, + "max_position_embeddings": 131072, + "model_type": "phi3", + "num_attention_heads": 40, + "num_hidden_layers": 40, + "num_key_value_heads": 10, + "original_max_position_embeddings": 4096, + "pad_token_id": null, + "resid_pdrop": 0.0, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "long_factor": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.25, + 1.25, + 1.5, + 2.0, + 2.75, + 5.75, + 5.75, + 6.5, + 9.25, + 11.0, + 13.25, + 19.25, + 19.75, + 19.75, + 21.25, + 21.5, + 26.5, + 30.0, + 33.75, + 35.25, + 38.5, + 42.0, + 42.25, + 46.0, + 47.0, + 50.0, + 50.5, + 51.0, + 52.0, + 52.75, + 53.75, + 54.75, + 57.0, + 57.25, + 58.5, + 59.25, + 59.5, + 62.0, + 62.5, + 62.75, + 63.25, + 63.25, + 63.25, + 63.75, + 64.0, + 64.0, + 64.25, + 64.5, + 64.5, + 65.0, + 65.0 + ], + "short_factor": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.01, + 1.02, + 1.02, + 1.04, + 1.04, + 1.07, + 1.07, + 1.1, + 1.3000000000000003, + 1.3000000000000003, + 1.5000000000000004, + 1.5700000000000005, + 1.9000000000000008, + 2.3100000000000014, + 2.759999999999992, + 3.3899999999999784, + 3.9399999999999666, + 4.009999999999965, + 4.289999999999959, + 4.349999999999958, + 5.349999999999937, + 6.659999999999909, + 7.029999999999901, + 7.51999999999989, + 8.00999999999988, + 8.249999999999876, + 8.279999999999875, + 9.629999999999846, + 9.89999999999984, + 10.589999999999826, + 11.049999999999816, + 11.7899999999998, + 12.189999999999792, + 12.889999999999777, + 13.129999999999772, + 13.16999999999977, + 13.20999999999977, + 13.479999999999764, + 13.539999999999763, + 13.779999999999758, + 13.929999999999755, + 14.429999999999744, + 14.759999999999737, + 15.149999999999729, + 15.419999999999723, + 15.53999999999972, + 15.659999999999718, + 15.749999999999716, + 15.759999999999716, + 15.799999999999715, + 16.05999999999971, + 16.079999999999714, + 16.11999999999972, + 16.11999999999972, + 16.18999999999973, + 16.31999999999975, + 16.539999999999786, + 16.799999999999827 + ], + "type": "su" + }, + "rope_theta": 10000.0, + "sliding_window": 131072, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.39.3", + "use_cache": true, + "attention_bias": false, + "vocab_size": 32064 +} diff --git a/src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-3-medium-4k-instruct-config.json b/src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-3-medium-4k-instruct-config.json new file mode 100644 index 0000000000..4ded05bfd8 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-3-medium-4k-instruct-config.json @@ -0,0 +1,36 @@ +{ + "_name_or_path": "Phi-3-medium-4k-instruct", + "architectures": [ + "Phi3ForCausalLM" + ], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_phi3.Phi3Config", + "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM" + }, + "bos_token_id": 1, + "embd_pdrop": 0.0, + "eos_token_id": 32000, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 17920, + "max_position_embeddings": 4096, + "model_type": "phi3", + "num_attention_heads": 40, + "num_hidden_layers": 40, + "num_key_value_heads": 10, + "original_max_position_embeddings": 4096, + "pad_token_id": 32000, + "resid_pdrop": 0.0, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 10000.0, + "sliding_window": 2047, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.39.3", + "use_cache": true, + "attention_bias": false, + "vocab_size": 32064 +} diff --git a/src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-3-mini-128k-instruct-config.json b/src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-3-mini-128k-instruct-config.json new file mode 100644 index 0000000000..9aae2a82fc --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-3-mini-128k-instruct-config.json @@ -0,0 +1,140 @@ +{ + "_name_or_path": "Phi-3-mini-128k-instruct", + "architectures": [ + "Phi3ForCausalLM" + ], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_phi3.Phi3Config", + "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM", + "AutoModelForSequenceClassification": "modeling_phi3.Phi3ForSequenceClassification", + "AutoModelForTokenClassification": "modeling_phi3.Phi3ForTokenClassification" + }, + "bos_token_id": 1, + "embd_pdrop": 0.0, + "eos_token_id": 32000, + "hidden_act": "silu", + "hidden_size": 3072, + "initializer_range": 0.02, + "intermediate_size": 8192, + "max_position_embeddings": 131072, + "model_type": "phi3", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "original_max_position_embeddings": 4096, + "pad_token_id": 32000, + "resid_pdrop": 0.0, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "long_factor": [ + 1.0299999713897705, + 1.0499999523162842, + 1.0499999523162842, + 1.0799999237060547, + 1.2299998998641968, + 1.2299998998641968, + 1.2999999523162842, + 1.4499999284744263, + 1.5999999046325684, + 1.6499998569488525, + 1.8999998569488525, + 2.859999895095825, + 3.68999981880188, + 5.419999599456787, + 5.489999771118164, + 5.489999771118164, + 9.09000015258789, + 11.579999923706055, + 15.65999984741211, + 15.769999504089355, + 15.789999961853027, + 18.360000610351562, + 21.989999771118164, + 23.079999923706055, + 30.009998321533203, + 32.35000228881836, + 32.590003967285156, + 35.56000518798828, + 39.95000457763672, + 53.840003967285156, + 56.20000457763672, + 57.95000457763672, + 59.29000473022461, + 59.77000427246094, + 59.920005798339844, + 61.190006256103516, + 61.96000671386719, + 62.50000762939453, + 63.3700065612793, + 63.48000717163086, + 63.48000717163086, + 63.66000747680664, + 63.850006103515625, + 64.08000946044922, + 64.760009765625, + 64.80001068115234, + 64.81001281738281, + 64.81001281738281 + ], + "short_factor": [ + 1.05, + 1.05, + 1.05, + 1.1, + 1.1, + 1.1500000000000001, + 1.2000000000000002, + 1.2500000000000002, + 1.3000000000000003, + 1.3500000000000003, + 1.5000000000000004, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.0500000000000007, + 2.0500000000000007, + 2.0500000000000007, + 2.1000000000000005, + 2.1000000000000005, + 2.1000000000000005, + 2.1500000000000004, + 2.1500000000000004, + 2.3499999999999996, + 2.549999999999999, + 2.5999999999999988, + 2.5999999999999988, + 2.7499999999999982, + 2.849999999999998, + 2.849999999999998, + 2.9499999999999975 + ], + "type": "su" + }, + "rope_theta": 10000.0, + "sliding_window": 262144, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.39.3", + "use_cache": true, + "attention_bias": false, + "vocab_size": 32064 +} diff --git a/src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-3-mini-4k-instruct-config.json b/src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-3-mini-4k-instruct-config.json new file mode 100644 index 0000000000..92bb5b42eb --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Resource/Config/phi-3-mini-4k-instruct-config.json @@ -0,0 +1,36 @@ +{ + "_name_or_path": "Phi-3-mini-4k-instruct", + "architectures": [ + "Phi3ForCausalLM" + ], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_phi3.Phi3Config", + "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM" + }, + "bos_token_id": 1, + "embd_pdrop": 0.0, + "eos_token_id": 32000, + "hidden_act": "silu", + "hidden_size": 3072, + "initializer_range": 0.02, + "intermediate_size": 8192, + "max_position_embeddings": 4096, + "model_type": "phi3", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "original_max_position_embeddings": 4096, + "pad_token_id": 32000, + "resid_pdrop": 0.0, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 10000.0, + "sliding_window": 2047, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.39.3", + "use_cache": true, + "attention_bias": false, + "vocab_size": 32064 +} diff --git a/src/Microsoft.ML.GenAI.Phi/Resource/README.md b/src/Microsoft.ML.GenAI.Phi/Resource/README.md new file mode 100644 index 0000000000..35f35bda1a --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Resource/README.md @@ -0,0 +1,8 @@ +## This folder includes the embedded resources for the GenAI.Phi project. + +### Configuration +- [phi-2-config.json](https://huggingface.co/microsoft/phi-2/blob/main/config.json): The phi-2 model configuration file. +- [phi-3-mini-128k-config.json](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json): The phi-3-mini-128k model configuration file. +- [phi-3-mini-4k-config.json](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json): The phi-3-mini-4k model configuration file. +- [phi-3-medium-4k-config.json](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct/blob/main/config.json): The phi-3-medium-4k model configuration file. +- [phi-3-medium-128k-config.json](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct/blob/main/config.json): The phi-3-medium-128k model configuration file.] \ No newline at end of file diff --git a/src/Microsoft.ML.GenAI.Phi/Utils.cs b/src/Microsoft.ML.GenAI.Phi/Utils.cs index c6f195d4b0..4591d94f14 100644 --- a/src/Microsoft.ML.GenAI.Phi/Utils.cs +++ b/src/Microsoft.ML.GenAI.Phi/Utils.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Reflection; using System.Text; using System.Threading.Tasks; using TorchSharp; @@ -13,8 +14,23 @@ using static TorchSharp.torch.nn; namespace Microsoft.ML.GenAI.Phi; -public static class Utils +internal static class Utils { + public static string GetEmbeddedResource(string resourceName) + { + // read file content from embedded resource + var assembly = Assembly.GetExecutingAssembly(); + var resourceStream = assembly.GetManifestResourceStream(resourceName); + + if (resourceStream == null) + { + throw new ArgumentException("Resource not found", nameof(resourceName)); + } + + using var reader = new System.IO.StreamReader(resourceStream); + return reader.ReadToEnd(); + } + public static Tensor ApplyRotaryEmbeddings(Tensor input, Tensor freqsComplex) { // Separate the last dimension pairs of two values, representing the real and imaginary parts of the complex number diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini128KLayerSizeTest.approved.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini128KLayerSizeTest.approved.txt index edb1e258bb..4f711239a7 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini128KLayerSizeTest.approved.txt +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini128KLayerSizeTest.approved.txt @@ -1,34 +1,34 @@ { - "model.layers.0": 216.01172, - "model.layers.1": 216.01172, - "model.layers.2": 216.01172, - "model.layers.3": 216.01172, - "model.layers.4": 216.01172, - "model.layers.5": 216.01172, - "model.layers.6": 216.01172, - "model.layers.7": 216.01172, - "model.layers.8": 216.01172, - "model.layers.9": 216.01172, - "model.layers.10": 216.01172, - "model.layers.11": 216.01172, - "model.layers.12": 216.01172, - "model.layers.13": 216.01172, - "model.layers.14": 216.01172, - "model.layers.15": 216.01172, - "model.layers.16": 216.01172, - "model.layers.17": 216.01172, - "model.layers.18": 216.01172, - "model.layers.19": 216.01172, - "model.layers.20": 216.01172, - "model.layers.21": 216.01172, - "model.layers.22": 216.01172, - "model.layers.23": 216.01172, - "model.layers.24": 216.01172, - "model.layers.25": 216.01172, - "model.layers.26": 216.01172, - "model.layers.27": 216.01172, - "model.layers.28": 216.01172, - "model.layers.29": 216.01172, - "model.layers.30": 216.01172, - "model.layers.31": 216.01172 + "model.layers.0": 216, + "model.layers.1": 216, + "model.layers.2": 216, + "model.layers.3": 216, + "model.layers.4": 216, + "model.layers.5": 216, + "model.layers.6": 216, + "model.layers.7": 216, + "model.layers.8": 216, + "model.layers.9": 216, + "model.layers.10": 216, + "model.layers.11": 216, + "model.layers.12": 216, + "model.layers.13": 216, + "model.layers.14": 216, + "model.layers.15": 216, + "model.layers.16": 216, + "model.layers.17": 216, + "model.layers.18": 216, + "model.layers.19": 216, + "model.layers.20": 216, + "model.layers.21": 216, + "model.layers.22": 216, + "model.layers.23": 216, + "model.layers.24": 216, + "model.layers.25": 216, + "model.layers.26": 216, + "model.layers.27": 216, + "model.layers.28": 216, + "model.layers.29": 216, + "model.layers.30": 216, + "model.layers.31": 216 } \ No newline at end of file diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs index 2666bd3175..44d1d74d7e 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs @@ -26,6 +26,7 @@ public class Phi2Test : BaseTestClass { public Phi2Test(ITestOutputHelper output) : base(output) { + torch.set_default_device("meta"); } [Fact] @@ -33,42 +34,11 @@ public Phi2Test(ITestOutputHelper output) : base(output) [UseApprovalSubdirectory("Approvals")] public void LoadSafeTensorShapeTest() { - torch.set_default_device("meta"); - var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\phi-2"; - var configName = "config.json"; - var config = Path.Join(modelWeightFolder, configName); - var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); - var model = new Phi2ForCasualLM(modelConfig); + var model = new Phi2ForCasualLM(Phi2Config.Phi2); var stateDictStr = model.PeekShape(); Approvals.Verify(stateDictStr); } - //[Fact] - //[UseReporter(typeof(DiffReporter))] - //[UseApprovalSubdirectory("Approvals")] - //public async Task ForwardTest() - //{ - // // create dummy input id with 128 length and attention mask - // var device = "cuda"; - // var inputIds = torch.arange(128, dtype: ScalarType.Int64, device: device).unsqueeze(0); - // var attentionMask = torch.ones(1, 128, device: device); - // var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\phi-2"; - // var model = Phi2ForCasualLM.FromPretrained(modelWeightFolder, torchDtype: ScalarType.BFloat16, checkPointName: "model.safetensors.index.json", device: "cuda"); - // var input = new CasualLMModelInput(inputIds, attentionMask, past_key_values_length: 0); - // var output = model.forward(input); - // var outputTokenIds = output.last_hidden_state; - // var outputLogits = output.logits; - - // var outputTokenIdsStr = outputTokenIds.Peek("output"); - // var outputLogitsStr = outputLogits.Peek("logits"); - - // var sb = new StringBuilder(); - // sb.AppendLine(outputTokenIdsStr); - // sb.AppendLine(outputLogitsStr); - - // Approvals.Verify(sb.ToString()); - //} - [Fact] [UseReporter(typeof(DiffReporter))] [UseApprovalSubdirectory("Approvals")] diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs index ef13df2bbc..532e4ca601 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs @@ -29,10 +29,7 @@ public Phi3Tests(ITestOutputHelper output) : base(output) [UseApprovalSubdirectory("Approvals")] public void Phi3Mini4KShapeTest() { - var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\Phi-3-mini-4k-instruct"; - var config = Path.Join(modelWeightFolder, "config.json"); - var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); - var model = new Phi3ForCasualLM(modelConfig); + var model = new Phi3ForCasualLM(Phi3Config.Phi3Mini4kInstruct); var stateDictStr = model.PeekShape(); Approvals.Verify(stateDictStr); } @@ -42,10 +39,7 @@ public void Phi3Mini4KShapeTest() [UseApprovalSubdirectory("Approvals")] public void Phi3Medium4KShapeTest() { - var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\Phi-3-medium-4k-instruct"; - var config = Path.Join(modelWeightFolder, "config.json"); - var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); - var model = new Phi3ForCasualLM(modelConfig); + var model = new Phi3ForCasualLM(Phi3Config.Phi3Medium4kInstruct); var stateDictStr = model.PeekShape(); Approvals.Verify(stateDictStr); } @@ -56,10 +50,7 @@ public void Phi3Medium4KShapeTest() [UseApprovalSubdirectory("Approvals")] public void Phi3Medium128KShapeTest() { - var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\Phi-3-medium-128k-instruct"; - var config = Path.Join(modelWeightFolder, "config.json"); - var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); - var model = new Phi3ForCasualLM(modelConfig); + var model = new Phi3ForCasualLM(Phi3Config.Phi3Medium128kInstruct); var stateDictStr = model.PeekShape(); Approvals.Verify(stateDictStr); } @@ -69,10 +60,7 @@ public void Phi3Medium128KShapeTest() [UseApprovalSubdirectory("Approvals")] public void Phi3Mini128KShapeTest() { - var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\Phi-3-mini-128k-instruct"; - var config = Path.Join(modelWeightFolder, "config.json"); - var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); - var model = new Phi3ForCasualLM(modelConfig); + var model = new Phi3ForCasualLM(Phi3Config.Phi3Mini128kInstruct); var stateDictStr = model.PeekShape(); Approvals.Verify(stateDictStr); } @@ -82,16 +70,10 @@ public void Phi3Mini128KShapeTest() [UseApprovalSubdirectory("Approvals")] public void Phi3Mini128KLayerSizeTest() { - var dtype = ScalarType.BFloat16; - var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\Phi-3-mini-128k-instruct"; - var config = Path.Join(modelWeightFolder, "config.json"); - var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); - modelConfig.DType = dtype; - var model = new Phi3ForCasualLM(modelConfig); - + var model = new Phi3ForCasualLM(Phi3Config.Phi3Mini128kInstruct); var size = model.GetSizeForEachDynamicLayerInBytes(); // convert size to MB - var sizeInMB = size.ToDictionary(x => x.Key, x => x.Value * 1.0f / 1024 / 1024); + var sizeInMB = size.ToDictionary(x => x.Key, x => x.Value / 1024 / 1024); var json = JsonSerializer.Serialize(sizeInMB, new JsonSerializerOptions { WriteIndented = true }); Approvals.Verify(json); From 1faddbc31c54e09f6fcb263a7d7465c5ac209bd3 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Thu, 27 Jun 2024 16:55:35 -0700 Subject: [PATCH 06/41] add semantic kernel and autogen intergration --- eng/Versions.props | 1 + .../Extension/CausalLMPipelineExtension.cs | 50 ----------- .../Pipeline/CausalLMPipeline.cs | 60 ++++++++++++- .../Extension/SemanticKernelExtension.cs | 33 +++++++ .../Microsoft.ML.GenAI.Phi.csproj | 2 + .../Phi3/Phi3CausalLMAgent.cs | 84 ++++++++++++++++++ .../Phi3/Phi3CausalLMChatCompletionService.cs | 81 +++++++++++++++++ .../Phi3/Phi3CausalLMTextGenerationService.cs | 66 ++++++++++++++ .../AutoGenTests.cs | 48 ++++++++++ .../Microsoft.ML.GenAI.Phi.Tests.csproj | 2 + .../SemanticKernelTests.cs | 88 +++++++++++++++++++ 11 files changed, 464 insertions(+), 51 deletions(-) delete mode 100644 src/Microsoft.ML.GenAI.Core/Extension/CausalLMPipelineExtension.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Extension/SemanticKernelExtension.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatCompletionService.cs create mode 100644 src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMTextGenerationService.cs create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/AutoGenTests.cs create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/SemanticKernelTests.cs diff --git a/eng/Versions.props b/eng/Versions.props index b1d4979662..b48e6485bd 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -75,6 +75,7 @@ 1.2.0 5.4.7 + 4.20.70 0.13.1 6.0.26 8.0.1 diff --git a/src/Microsoft.ML.GenAI.Core/Extension/CausalLMPipelineExtension.cs b/src/Microsoft.ML.GenAI.Core/Extension/CausalLMPipelineExtension.cs deleted file mode 100644 index b9d6bb3f1c..0000000000 --- a/src/Microsoft.ML.GenAI.Core/Extension/CausalLMPipelineExtension.cs +++ /dev/null @@ -1,50 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -using TorchSharp; -using static TorchSharp.torch; - -namespace Microsoft.ML.GenAI.Core.Extension; - -public static class CausalLMPipelineExtension -{ - public static string? Generate( - this CausalLMPipeline pipeline, - string prompt, - int maxLen = 128, - float temperature = 0.7f, - float topP = 0.9f, - string[]? stopSequences = null, - int eosId = 0, - string device = "cpu", - bool bos = true, - bool eos = false, - bool echo = false) - { - using var newScope = NewDisposeScope(); - var inputIds = pipeline.Tokenizer.EncodeToIds(prompt); - var inputTensor = torch.tensor(inputIds.ToArray(), dtype: ScalarType.Int64, device: device).unsqueeze(0); - var attentionMask = torch.ones_like(inputTensor); - - // set up stop token ids - // stop token ids: [[eosId], [stopSequence1], [stopSequence2], ...] - // when causal language model generates tokens, it will stop when it generates any token in stopSequences - List stopTokenIds = [[eosId]]; - if (stopSequences != null) - { - stopTokenIds.AddRange(stopSequences.Select(x => pipeline.Tokenizer.EncodeToIds(x).ToArray())); - } - - (var token, var _) = pipeline.Generate(inputTensor, attentionMask, temperature: temperature, maxLen: maxLen, topP: topP, stopTokenSequence: stopTokenIds.ToArray(), echo: echo); - - var tokenIds = token[0].to_type(ScalarType.Int32).data().ToArray(); - - return pipeline.Tokenizer.Decode(tokenIds); - } -} diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs index bc7d3c8e0d..63d2b29947 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs @@ -13,7 +13,7 @@ namespace Microsoft.ML.GenAI.Core; -public class CausalLMPipeline : CausalLMPipeline +public class CausalLMPipeline : CausalLMPipeline, ICausalLMPipeline where TTokenizer : Tokenizer where TModel : nn.Module { @@ -24,6 +24,25 @@ public CausalLMPipeline( : base(tokenizer, model, device) { } + + internal CausalLMPipeline() + : base() + { + } +} + +public interface ICausalLMPipeline + where TTokenizer : Tokenizer + where TModel : nn.Module +{ + string? Generate( + string prompt, + int maxLen = 128, + float temperature = 0.7F, + float topP = 0.9F, + string[]? stopSequences = null, + bool echo = false); + (Tensor, Tensor) Generate(Tensor inputIds, Tensor attentionMask, int[][] stopTokenSequence, float temperature = 0.7F, float topP = 0.9F, int maxLen = 128, bool echo = false); } public class CausalLMPipeline @@ -38,6 +57,16 @@ public CausalLMPipeline( this.Device = device; } + /// + /// For moq purpose + /// + protected private CausalLMPipeline() + { + this.Tokenizer = default!; + this.Model = default!; + this.Device = default!; + } + public Tokenizer Tokenizer { get; } public nn.Module Model { get; } @@ -131,6 +160,35 @@ public virtual ( } } + public virtual string? Generate( + string prompt, + int maxLen = 128, + float temperature = 0.7f, + float topP = 0.9f, + string[]? stopSequences = null, + bool echo = false) + { + using var newScope = NewDisposeScope(); + var inputIds = this.Tokenizer.EncodeToIds(prompt); + var inputTensor = torch.tensor(inputIds.ToArray(), dtype: ScalarType.Int64, device: this.Device).unsqueeze(0); + var attentionMask = torch.ones_like(inputTensor); + + // set up stop token ids + // stop token ids: [[eosId], [stopSequence1], [stopSequence2], ...] + // when causal language model generates tokens, it will stop when it generates any token in stopSequences + List stopTokenIds = [[]]; + if (stopSequences != null) + { + stopTokenIds.AddRange(stopSequences.Select(x => this.Tokenizer.EncodeToIds(x).ToArray())); + } + + (var token, var _) = this.Generate(inputTensor, attentionMask, temperature: temperature, maxLen: maxLen, topP: topP, stopTokenSequence: stopTokenIds.ToArray(), echo: echo); + + var tokenIds = token[0].to_type(ScalarType.Int32).data().ToArray(); + + return this.Tokenizer.Decode(tokenIds); + } + protected torch.Tensor SampleTopP(torch.Tensor logits, float topP) { (var probsSort, var probsIndex) = torch.sort(logits, dim: -1, descending: true); diff --git a/src/Microsoft.ML.GenAI.Phi/Extension/SemanticKernelExtension.cs b/src/Microsoft.ML.GenAI.Phi/Extension/SemanticKernelExtension.cs new file mode 100644 index 0000000000..c2ef497d64 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Extension/SemanticKernelExtension.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.ML.GenAI.Core; +using Microsoft.ML.Tokenizers; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.TextGeneration; + +namespace Microsoft.ML.GenAI.Phi.Extension; + +public static class SemanticKernelExtension +{ + public static IKernelBuilder AddPhi3AsChatCompletion( + this IKernelBuilder builder, + ICausalLMPipeline pipeline) + { + builder.Services.AddSingleton(new Phi3CausalLMChatCompletionService(pipeline)); + + return builder; + } + + public static IKernelBuilder AddPhi3AsTextGeneration( + this IKernelBuilder builder, + ICausalLMPipeline pipeline) + { + builder.Services.AddSingleton(new Phi3CausalLMTextGenerationService(pipeline)); + + return builder; + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj index 59f77e73c8..19b1c156f2 100644 --- a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj +++ b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj @@ -7,12 +7,14 @@ + + diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs new file mode 100644 index 0000000000..c4ff1a8b23 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs @@ -0,0 +1,84 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using AutoGen.Core; +using Microsoft.ML.GenAI.Core; +using Microsoft.ML.Tokenizers; + +namespace Microsoft.ML.GenAI.Phi; + +public class Phi3Agent : IAgent +{ + private const char Newline = '\n'; + private readonly ICausalLMPipeline _pipeline; + private readonly string? _systemMessage; + + public Phi3Agent( + ICausalLMPipeline pipeline, + string name, + string? systemMessage = "you are a helpful assistant") + { + this.Name = name; + this._pipeline = pipeline; + this._systemMessage = systemMessage; + } + + public string Name { get; } + + public Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + { + var availableRoles = new[] { Role.System, Role.User, Role.Assistant }; + if (messages.Any(m => m.GetContent() is null)) + { + throw new InvalidOperationException("Please provide a message with content."); + } + + if (messages.Any(m => m.GetRole() is null || availableRoles.Contains(m.GetRole()!.Value) == false)) + { + throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant."); + } + + // construct template based on instruction from + // https://huggingface.co/microsoft/Phi-3-mini-128k-instruct#chat-format + + var sb = new StringBuilder(); + if (_systemMessage is not null) + { + sb.Append($"<|system|>{Newline}{_systemMessage}<|end|>{Newline}"); + } + foreach (var message in messages) + { + var role = message.GetRole()!.Value; + var content = message.GetContent()!; + sb.Append(message switch + { + _ when message.GetRole() == Role.System => $"<|system|>{Newline}{content}<|end|>{Newline}", + _ when message.GetRole() == Role.User => $"<|user|>{Newline}{content}<|end|>{Newline}", + _ when message.GetRole() == Role.Assistant => $"<|assistant|>{Newline}{content}<|end|>{Newline}", + _ => throw new InvalidOperationException("Invalid role.") + }); + } + + sb.Append("<|assistant|>"); + var input = sb.ToString(); + + var maxLen = options?.MaxToken ?? 1024; + var temperature = options?.Temperature ?? 0.7f; + var stopTokenSequence = options?.StopSequence ?? []; + stopTokenSequence = stopTokenSequence.Append("<|end|>").ToArray(); + + var output = _pipeline.Generate( + input, + maxLen: maxLen, + temperature: temperature, + stopSequences: stopTokenSequence) ?? throw new InvalidOperationException("Failed to generate a reply."); + + return Task.FromResult(new TextMessage(Role.Assistant, output, from: this.Name)); + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatCompletionService.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatCompletionService.cs new file mode 100644 index 0000000000..4888d7a0f8 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatCompletionService.cs @@ -0,0 +1,81 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.CompilerServices; +using System.Text; +using Microsoft.ML.GenAI.Core; +using Microsoft.ML.Tokenizers; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.TextGeneration; + +namespace Microsoft.ML.GenAI.Phi; + +public class Phi3CausalLMChatCompletionService : IChatCompletionService +{ + private readonly ICausalLMPipeline _pipeline; + private readonly Phi3CausalLMTextGenerationService _textGenerationService; + private const char NewLine = '\n'; // has to be \n, \r\n will cause wanky result. + + public Phi3CausalLMChatCompletionService(ICausalLMPipeline pipeline) + { + _pipeline = pipeline; + _textGenerationService = new Phi3CausalLMTextGenerationService(pipeline); + } + + public IReadOnlyDictionary Attributes => _textGenerationService.Attributes; + + public async Task> GetChatMessageContentsAsync( + ChatHistory chatHistory, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + CancellationToken cancellationToken = default) + { + // build prompt from chat history + var sb = new StringBuilder(); + + foreach (var message in chatHistory) + { + foreach (var item in message.Items) + { + if (item is not TextContent textContent) + { + throw new NotSupportedException($"Only text content is supported, but got {item.GetType().Name}"); + } + + var prompt = message.Role switch + { + _ when message.Role == AuthorRole.System => $"<|system|>{NewLine}{textContent}<|end|>{NewLine}", + _ when message.Role == AuthorRole.User => $"<|user|>{NewLine}{textContent}<|end|>{NewLine}", + _ when message.Role == AuthorRole.Assistant => $"<|assistant|>{NewLine}{textContent}<|end|>{NewLine}", + _ => throw new NotSupportedException($"Unsupported role {message.Role}") + }; + + sb.Append(prompt); + } + } + + sb.Append("<|assistant|>"); + var reply = await _textGenerationService.GetTextContentAsync(sb.ToString(), executionSettings, kernel, cancellationToken); + return [new ChatMessageContent(AuthorRole.Assistant, reply.Text)]; + } + + public async IAsyncEnumerable GetStreamingChatMessageContentsAsync( + ChatHistory chatHistory, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + [EnumeratorCancellation] + CancellationToken cancellationToken = default) + { + // CausalLMPipeline doesn't support streaming output yet + // here we simply implement this api using the synchronous version + + var response = await GetChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken); + + foreach (var item in response) + { + yield return new StreamingChatMessageContent(item.Role, item.Content); + } + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMTextGenerationService.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMTextGenerationService.cs new file mode 100644 index 0000000000..a94c91fe0b --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMTextGenerationService.cs @@ -0,0 +1,66 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.CompilerServices; +using Microsoft.ML.GenAI.Core; +using Microsoft.ML.Tokenizers; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.TextGeneration; + +namespace Microsoft.ML.GenAI.Phi; + +public class Phi3CausalLMTextGenerationService : ITextGenerationService +{ + private readonly ICausalLMPipeline _pipeline; + + public Phi3CausalLMTextGenerationService(ICausalLMPipeline pipeline) + { + _pipeline = pipeline; + } + + public IReadOnlyDictionary Attributes => new Dictionary() + { + { "temperature", null }, + { "max_token", null }, + { "stop_token_sequence", null }, + { "top_p", null }, + }; + + public Task> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + { + var temperature = executionSettings?.ExtensionData?["temperature"] as float? ?? 0.7f; + var maxToken = executionSettings?.ExtensionData?["max_token"] as int? ?? 100; + var stopTokenSequence = executionSettings?.ExtensionData?["stop_token_sequence"] as string[] ?? Array.Empty(); + var topP = executionSettings?.ExtensionData?["top_p"] as float? ?? 0.9f; + stopTokenSequence.Append("<|end|>"); + + var response = _pipeline.Generate( + prompt, + maxToken, + temperature, + stopSequences: stopTokenSequence, + topP: topP); + + return Task.FromResult>([new TextContent(response)]); + } + + public async IAsyncEnumerable GetStreamingTextContentsAsync( + string prompt, + PromptExecutionSettings? + executionSettings = null, + Kernel? kernel = null, + [EnumeratorCancellation] + CancellationToken cancellationToken = default) + { + // CausalLMPipeline doesn't support streaming output yet + // here we simply implement this api using the synchronous version + + var response = await GetTextContentsAsync(prompt, executionSettings, kernel, cancellationToken); + + foreach (var item in response) + { + yield return new StreamingTextContent(item.Text); + } + } +} diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/AutoGenTests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/AutoGenTests.cs new file mode 100644 index 0000000000..21ae811fda --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/AutoGenTests.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using AutoGen.Core; +using FluentAssertions; +using Microsoft.ML.GenAI.Core; +using Microsoft.ML.TestFramework; +using Microsoft.ML.Tokenizers; +using Moq; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.GenAI.Phi.Tests; + +public class AutoGenTests : BaseTestClass +{ + public AutoGenTests(ITestOutputHelper helper) + : base(helper) + { + } + + [Fact] + public async Task ItGenerateTextReply() + { + var pipeline = Mock.Of>(); + // mock generate api + Mock.Get(pipeline).Setup(p => p.Generate( + It.IsAny(), // prompt + It.IsAny(), // max length + It.IsAny(), // temperature + It.IsAny(), // top_p + It.IsAny(), // stop sequence + It.IsAny())) // echo + .Callback((string prompt, int maxLen, float temperature, float topP, string[] stopSequences, bool echo) => + { + // check prompt + prompt.Should().Be("<|system|>\nyou are a helpful assistant<|end|>\n<|user|>\nhey<|end|>\n<|assistant|>"); + }) + .Returns((string prompt, int maxLen, float temperature, float topP, string[] stopSequences, bool echo) => "hello"); + + var agent = new Phi3Agent(pipeline, "assistant"); + var reply = await agent.SendAsync("hey"); + + reply.GetContent().Should().Be("hello"); + reply.From.Should().Be(agent.Name); + } +} diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj index ed7f2c6342..5d1d22f4f0 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj @@ -16,6 +16,8 @@ + + diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/SemanticKernelTests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/SemanticKernelTests.cs new file mode 100644 index 0000000000..ec61de4601 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/SemanticKernelTests.cs @@ -0,0 +1,88 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using FluentAssertions; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.ML.GenAI.Core; +using Microsoft.ML.GenAI.Phi.Extension; +using Microsoft.ML.TestFramework; +using Microsoft.ML.Tokenizers; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Moq; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.GenAI.Phi.Tests; + +public class SemanticKernelTests : BaseTestClass +{ + public SemanticKernelTests(ITestOutputHelper helper) + : base(helper) + { + } + + [Fact] + public async Task ItAddPhi3CausalLMChatCompletionServiceTestAsync() + { + var pipeline = Mock.Of>(MockBehavior.Loose); + // mock generate api + Mock.Get(pipeline).Setup(p => p.Generate( + It.IsAny(), // prompt + It.IsAny(), // max length + It.IsAny(), // temperature + It.IsAny(), // top_p + It.IsAny(), // stop sequence + It.IsAny())) // echo + .Callback((string prompt, int maxLen, float temperature, float topP, string[] stopSequences, bool echo) => + { + // check prompt + prompt.Should().Be("<|system|>\nyou are a helpful assistant<|end|>\n<|user|>\nhey<|end|>\n<|assistant|>"); + }) + .Returns((string prompt, int maxLen, float temperature, float topP, string[] stopSequences, bool echo) => "hello"); + + var kernel = Kernel.CreateBuilder() + .AddPhi3AsChatCompletion(pipeline) + .Build(); + + var chatService = kernel.Services.GetRequiredService(); + + var chatHistory = new ChatHistory(); + chatHistory.AddSystemMessage("you are a helpful assistant"); + chatHistory.AddUserMessage("hey"); + var responses = await chatService.GetChatMessageContentsAsync(chatHistory); + responses.Count().Should().Be(1); + var response = responses.First(); + response.Role.Should().Be(AuthorRole.Assistant); + response.Items.Count().Should().Be(1); + var textContent = response.Items.First() as TextContent; + textContent!.Text.Should().Be("hello"); + } + + [Fact] + public async Task ItAddPhi3CausalLMTextGenerationServiceTestAsync() + { + var pipeline = Mock.Of>(MockBehavior.Loose); + // mock generate api + Mock.Get(pipeline).Setup(p => p.Generate( + It.IsAny(), // prompt + It.IsAny(), // max length + It.IsAny(), // temperature + It.IsAny(), // top_p + It.IsAny(), // stop sequence + It.IsAny())) // echo + .Callback((string prompt, int maxLen, float temperature, float topP, string[] stopSequences, bool echo) => + { + // check prompt + prompt.Should().Be("test"); + }) + .Returns((string prompt, int maxLen, float temperature, float topP, string[] stopSequences, bool echo) => "hello"); + + var kernel = Kernel.CreateBuilder() + .AddPhi3AsTextGeneration(pipeline) + .Build(); + + var response = await kernel.InvokePromptAsync("test"); + } +} From 2802ed32eee84dd1945da9e3f8ee82e32f3e5ea4 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Thu, 27 Jun 2024 17:38:09 -0700 Subject: [PATCH 07/41] update --- .../Microsoft.ML.GenAI.Core.csproj | 3 --- src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs | 9 ++++++--- src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs | 4 ---- test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs | 8 ++++++++ 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj index ad4520e133..17d1b45dbb 100644 --- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj +++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj @@ -10,9 +10,6 @@ - - - diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs index 63d2b29947..c60122fa7b 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. @@ -7,6 +7,7 @@ using System.Linq; using System.Text; using System.Threading.Tasks; +using Microsoft.ML.GenAI.Core.Extension; using Microsoft.ML.Tokenizers; using TorchSharp; using static TorchSharp.torch; @@ -31,7 +32,7 @@ internal CausalLMPipeline() } } -public interface ICausalLMPipeline +public interface ICausalLMPipeline where TTokenizer : Tokenizer where TModel : nn.Module { @@ -179,9 +180,11 @@ public virtual ( List stopTokenIds = [[]]; if (stopSequences != null) { - stopTokenIds.AddRange(stopSequences.Select(x => this.Tokenizer.EncodeToIds(x).ToArray())); + stopTokenIds.AddRange(stopSequences.Select(x => this.Tokenizer.EncodeToIds(x, false, false).ToArray())); } + stopTokenIds = stopTokenIds.Where(ids => ids.Count() > 0).ToList(); + (var token, var _) = this.Generate(inputTensor, attentionMask, temperature: temperature, maxLen: maxLen, topP: topP, stopTokenSequence: stopTokenIds.ToArray(), echo: echo); var tokenIds = token[0].to_type(ScalarType.Int32).data().ToArray(); diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs index dada1fa64a..e204def0ec 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs @@ -86,10 +86,6 @@ public override IReadOnlyList Encode(ReadOnlySpan text, out string? public override IReadOnlyList EncodeToIds(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true) { var input = text.ToString(); - if (this._addPrecedingSpace) - { - //input = " " + input; - } // step 1: // replace all special tokens to var re = new Regex($"{SystemSymbol.Replace("|", "\\|")}|{UserSymbol.Replace("|", "\\|")}|{AssistantSymbol.Replace("|", "\\|")}|{EndSymbol.Replace("|", "\\|")}"); diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs index 532e4ca601..d83777a2f7 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs @@ -88,12 +88,20 @@ public void TokenizerTest() var tokenizer = Phi3Tokenizer.FromPretrained(modelWeightFolder); tokenizer.BosId.Should().Be(1); tokenizer.EosId.Should().Be(2); + + // test <|end|> + var endIds = tokenizer.EncodeToIds("<|end|>", considerPreTokenization: false, considerNormalization: false); + endIds.Should().BeEquivalentTo(new int[] { 32007 }); + var messages = new string[] { "Can you provide ways to eat combinations of bananas and dragonfruits?", "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey.", "What about solving an 2x + 3 = 7 equation?", "\nCount to 3\n", + "<|user|>", + "<|end|>", + "<|assistant|>", "<|user|>\nCount to 3<|end|>\n<|assistant|>", }; var sb = new StringBuilder(); From c8ab578c994a637563f6f0fa3254a6518207fb87 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 28 Jun 2024 01:23:03 -0700 Subject: [PATCH 08/41] add Microsoft.ML.GenAI.Sample --- Microsoft.ML.sln | 13 +- .../Microsoft.ML.GenAI.Samples.csproj | 20 ++ .../Phi3Mini/AutoGenSample.cs | 39 +++ .../Phi3Mini/SemanticKernelSample.cs | 66 +++++ .../Phi3Mini/Utils.cs | 48 ++++ .../Microsoft.ML.GenAI.Samples/Program.cs | 4 + .../Extension/ModuleExtension.cs | 52 +--- .../Pipeline/CausalLMPipeline.cs | 246 +++++++++++------- .../Microsoft.ML.GenAI.Phi.csproj | 1 + .../Phi3/Phi3CausalLMAgent.cs | 56 +++- .../Phi3/Phi3CausalLMChatCompletionService.cs | 42 +-- .../Phi3/Phi3CausalLMTextGenerationService.cs | 29 ++- .../Phi3/Phi3Tokenzier.cs | 64 ++--- .../Phi3Tests.TokenizerTest.approved.txt | 3 + .../AutoGenTests.cs | 7 +- .../Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs | 8 +- .../SemanticKernelTests.cs | 18 +- 17 files changed, 472 insertions(+), 244 deletions(-) create mode 100644 docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj create mode 100644 docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs create mode 100644 docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs create mode 100644 docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs create mode 100644 docs/samples/Microsoft.ML.GenAI.Samples/Program.cs diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 824f88dd5f..c30f67410a 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -180,7 +180,9 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.Core", " EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.Phi", "src\Microsoft.ML.GenAI.Phi\Microsoft.ML.GenAI.Phi.csproj", "{694BF884-B2E4-4E1C-9342-0564BAAC4575}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.Phi.Tests", "test\Microsoft.ML.GenAI.Phi.Tests\Microsoft.ML.GenAI.Phi.Tests.csproj", "{867FFC34-DFA7-400F-B9BB-85158326CE08}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.Phi.Tests", "test\Microsoft.ML.GenAI.Phi.Tests\Microsoft.ML.GenAI.Phi.Tests.csproj", "{867FFC34-DFA7-400F-B9BB-85158326CE08}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.Samples", "docs\samples\Microsoft.ML.GenAI.Samples\Microsoft.ML.GenAI.Samples.csproj", "{1D4AD9A3-19AF-432B-889D-A63FE6D7BD47}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -858,6 +860,14 @@ Global {867FFC34-DFA7-400F-B9BB-85158326CE08}.Release|Any CPU.Build.0 = Release|Any CPU {867FFC34-DFA7-400F-B9BB-85158326CE08}.Release|x64.ActiveCfg = Release|Any CPU {867FFC34-DFA7-400F-B9BB-85158326CE08}.Release|x64.Build.0 = Release|Any CPU + {1D4AD9A3-19AF-432B-889D-A63FE6D7BD47}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1D4AD9A3-19AF-432B-889D-A63FE6D7BD47}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1D4AD9A3-19AF-432B-889D-A63FE6D7BD47}.Debug|x64.ActiveCfg = Debug|Any CPU + {1D4AD9A3-19AF-432B-889D-A63FE6D7BD47}.Debug|x64.Build.0 = Debug|Any CPU + {1D4AD9A3-19AF-432B-889D-A63FE6D7BD47}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1D4AD9A3-19AF-432B-889D-A63FE6D7BD47}.Release|Any CPU.Build.0 = Release|Any CPU + {1D4AD9A3-19AF-432B-889D-A63FE6D7BD47}.Release|x64.ActiveCfg = Release|Any CPU + {1D4AD9A3-19AF-432B-889D-A63FE6D7BD47}.Release|x64.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -947,6 +957,7 @@ Global {DB2CA055-8ABD-4E3E-8089-5B64C3415E85} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {694BF884-B2E4-4E1C-9342-0564BAAC4575} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {867FFC34-DFA7-400F-B9BB-85158326CE08} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} + {1D4AD9A3-19AF-432B-889D-A63FE6D7BD47} = {DA452A53-2E94-4433-B08C-041EDEC729E6} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj new file mode 100644 index 0000000000..e522cff52e --- /dev/null +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj @@ -0,0 +1,20 @@ + + + + Exe + net8.0 + enable + enable + + + + + + + + + + + + + diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs new file mode 100644 index 0000000000..be26cc035e --- /dev/null +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs @@ -0,0 +1,39 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using AutoGen.Core; +using Microsoft.ML.GenAI.Phi; +using static TorchSharp.torch; +using TorchSharp; +using Microsoft.ML.GenAI.Core; +using Microsoft.ML.GenAI.Core.Extension; + +namespace Microsoft.ML.GenAI.Samples.Phi3Mini; + +public class AutoGenSample +{ + public static async Task RunAsync() + { + var device = "cuda"; + if (device == "cuda") + { + torch.InitializeDeviceType(DeviceType.CUDA); + } + + var defaultType = ScalarType.Float16; + torch.manual_seed(1); + torch.set_default_dtype(defaultType); + var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"; + var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device); + + // agent + var agent = new Phi3Agent(pipeline, "assistant") + .RegisterPrintMessage(); + var question = @"write a C# program to calculate the factorial of a number"; + + // chat with the assistant + await agent.SendAsync(question); + } +} diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs new file mode 100644 index 0000000000..0a3016dcde --- /dev/null +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs @@ -0,0 +1,66 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.ML.GenAI.Phi; +using static TorchSharp.torch; +using TorchSharp; +using Microsoft.SemanticKernel; +using Microsoft.ML.GenAI.Phi.Extension; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Microsoft.ML.GenAI.Samples.Phi3Mini; + +public class SemanticKernelSample +{ + public static async Task RunChatCompletionSample() + { + var device = "cuda"; + if (device == "cuda") + { + torch.InitializeDeviceType(DeviceType.CUDA); + } + + var defaultType = ScalarType.Float16; + torch.manual_seed(1); + torch.set_default_dtype(defaultType); + var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"; + var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device); + + + var kernel = Kernel.CreateBuilder() + .AddPhi3AsChatCompletion(pipeline) + .Build(); + var chatService = kernel.GetRequiredService(); + var chatHistory = new ChatHistory(); + chatHistory.AddSystemMessage("you are a helpful assistant"); + chatHistory.AddUserMessage("write a C# program to calculate the factorial of a number"); + + var response = await chatService.GetChatMessageContentAsync(chatHistory); + Console.WriteLine(response); + } + + public static async Task RunTextGenerationSample() + { + var device = "cuda"; + if (device == "cuda") + { + torch.InitializeDeviceType(DeviceType.CUDA); + } + + var defaultType = ScalarType.Float16; + torch.manual_seed(1); + torch.set_default_dtype(defaultType); + var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"; + var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device); + + + var kernel = Kernel.CreateBuilder() + .AddPhi3AsTextGeneration(pipeline) + .Build(); + + var response = await kernel.InvokePromptAsync("write a C# program to calculate the factorial of a number"); + Console.WriteLine(response); + } +} diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs new file mode 100644 index 0000000000..33769d6330 --- /dev/null +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.ML.GenAI.Core; +using Microsoft.ML.GenAI.Phi; +using Tensorboard; +using static TorchSharp.torch; +using TorchSharp; +using Microsoft.ML.GenAI.Core.Extension; + +namespace Microsoft.ML.GenAI.Samples.Phi3Mini; + +internal static class Utils +{ + public static CausalLMPipeline LoadPhi3Mini4KFromFolder( + string weightFolder, + string device = "cuda", + int modelSizeOnCudaInGB = 16, + int modelSizeOnMemoryInGB = 64, + int modelSizeOnDiskInGB = 200) + { + var defaultType = ScalarType.Float16; + Console.WriteLine("Loading Phi3 from huggingface model weight folder"); + var timer = System.Diagnostics.Stopwatch.StartNew(); + var model = Phi3ForCasualLM.FromPretrained(weightFolder, device: device, torchDtype: defaultType, checkPointName: "model.safetensors.index.json"); + var tokenizer = Phi3Tokenizer.FromPretrained(weightFolder); + var deviceSizeMap = new Dictionary + { + ["cuda:0"] = modelSizeOnCudaInGB * 1024 * 1024 * 1024, + ["cpu"] = modelSizeOnMemoryInGB * 1024 * 1024 * 1024, + ["disk"] = modelSizeOnDiskInGB * 1024 * 1024 * 1024, + }; + + var deviceMap = model.InferDeviceMapForEachLayer( + devices: ["cuda:0", "cpu", "disk"], + deviceSizeMapInByte: deviceSizeMap); + + model = model.ToDynamicLoadingModel(deviceMap, "cuda:0"); + + var pipeline = new CausalLMPipeline(tokenizer, model, device); + timer.Stop(); + Console.WriteLine($"Phi3 loaded in {timer.ElapsedMilliseconds / 1000} s"); + + return pipeline; + } +} diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs new file mode 100644 index 0000000000..63e18e5f42 --- /dev/null +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs @@ -0,0 +1,4 @@ +// See https://aka.ms/new-console-template for more information +using Microsoft.ML.GenAI.Samples.Phi3Mini; + +await SemanticKernelSample.RunTextGenerationSample(); diff --git a/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs b/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs index 6395ffd3fd..a3fd98b1f3 100644 --- a/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs +++ b/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs @@ -4,17 +4,14 @@ using System; using System.Collections.Generic; -using System.IO; using System.Linq; using System.Text; -using Microsoft.ML.GenAI.Core; -using Microsoft.ML.GenAI.Core.Extension; using TorchSharp; using static TorchSharp.torch; namespace Microsoft.ML.GenAI.Core.Extension; -internal static class ModuleExtension +public static class ModuleExtension { public static long GetSizeInBytes(this nn.Module model) { @@ -159,7 +156,7 @@ public static Dictionary InferDeviceMapForEachLayer( return deviceMap; } - public static string Peek(this nn.Module model) + internal static string Peek(this nn.Module model) { var sb = new StringBuilder(); var stateDict = model.state_dict(); @@ -177,7 +174,7 @@ public static string Peek(this nn.Module model) return res; } - public static string PeekShape(this nn.Module model) + internal static string PeekShape(this nn.Module model) { var sb = new StringBuilder(); var stateDict = model.state_dict(); @@ -195,47 +192,4 @@ public static string PeekShape(this nn.Module model) return res; } - - public static void LoadStateDict(this Dictionary dict, string location) - { - using FileStream stream = File.OpenRead(location); - using BinaryReader reader = new BinaryReader(stream); - var num = reader.Decode(); - for (int i = 0; i < num; i++) - { - var key = reader.ReadString(); - Tensor tensor = dict[key]; - - var originalDevice = tensor.device; - var originalType = tensor.dtype; - if (tensor.dtype == ScalarType.BFloat16) - { - tensor = tensor.to_type(ScalarType.Float32); - } - - TensorExtensionMethods.Load(ref tensor!, reader, skip: false); - - tensor = tensor!.to_type(originalType); - dict[key] = tensor; - } - } - - public static long Decode(this BinaryReader reader) - { - long num = 0L; - int num2 = 0; - while (true) - { - long num3 = reader.ReadByte(); - num += (num3 & 0x7F) << num2 * 7; - if ((num3 & 0x80) == 0L) - { - break; - } - - num2++; - } - - return num; - } } diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs index c60122fa7b..417d1d87a2 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. @@ -14,6 +14,48 @@ namespace Microsoft.ML.GenAI.Core; +public interface ICausalLMPipeline : ICausalLMPipeline + where TTokenizer : Tokenizer + where TModel : nn.Module +{ + TTokenizer Tokenizer { get; } + + TModel Model { get; } +} + +public interface ICausalLMPipeline +{ + string Generate( + string prompt, + int maxLen = 128, + float temperature = 0.7F, + float topP = 0.9F, + string[]? stopSequences = null); + + IEnumerable GenerateStreaming( + string prompt, + int maxLen = 128, + float temperature = 0.7F, + float topP = 0.9F, + string[]? stopSequences = null); + + (Tensor, Tensor) Generate( + Tensor inputIds, + Tensor attentionMask, + int[][] stopTokenSequence, + float temperature = 0.7F, + float topP = 0.9F, + int maxLen = 128); + + IEnumerable<(Tensor, Tensor)> GenerateStreaming( + Tensor inputIds, + Tensor attentionMask, + int[][] stopTokenSequence, + float temperature = 0.7F, + float topP = 0.9F, + int maxLen = 128); +} + public class CausalLMPipeline : CausalLMPipeline, ICausalLMPipeline where TTokenizer : Tokenizer where TModel : nn.Module @@ -26,27 +68,12 @@ public CausalLMPipeline( { } - internal CausalLMPipeline() - : base() - { - } -} + public new TTokenizer Tokenizer { get => (TTokenizer)base.Tokenizer; } -public interface ICausalLMPipeline - where TTokenizer : Tokenizer - where TModel : nn.Module -{ - string? Generate( - string prompt, - int maxLen = 128, - float temperature = 0.7F, - float topP = 0.9F, - string[]? stopSequences = null, - bool echo = false); - (Tensor, Tensor) Generate(Tensor inputIds, Tensor attentionMask, int[][] stopTokenSequence, float temperature = 0.7F, float topP = 0.9F, int maxLen = 128, bool echo = false); + public new TModel Model { get => (TModel)base.Model; } } -public class CausalLMPipeline +public class CausalLMPipeline : ICausalLMPipeline { public CausalLMPipeline( Tokenizer tokenizer, @@ -74,106 +101,134 @@ protected private CausalLMPipeline() public Device Device { get; } - public virtual ( - Tensor, // output token ids [batch_size, sequence_length] - Tensor // output logits [batch_size, sequence_length, vocab_size] - ) Generate( - Tensor inputIds, // input token ids [batch_size, sequence_length] - Tensor attentionMask, // attention mask [batch_size, sequence_length] + public IEnumerable<( + Tensor, // output token ids [batch_size, 1] + Tensor // output logits [batch_size, 1, vocab_size] + )> GenerateStreaming( + Tensor inputIds, + Tensor attentionMask, int[][] stopTokenSequence, - float temperature = 0.7f, - float topP = 0.9f, - int maxLen = 128, - bool echo = false) + float temperature = 0.7F, + float topP = 0.9F, + int maxLen = 128) { - using var newScope = NewDisposeScope(); + using var scope = NewDisposeScope(); + using var noGrad = torch.no_grad(); var batch = inputIds.shape[0]; var device = inputIds.device; var promptLength = (int)inputIds.shape[1]; var totalLen = promptLength + maxLen; - using (var noGrad = torch.no_grad()) + var prevPos = 0; + var eosReached = torch.tensor(new bool[batch], device: device); + torch.Tensor? logits = default; + var cache = new DynamicKVCache(); + if (promptLength == totalLen) + { + var input = new CasualLMModelInput(inputIds, attentionMask, pastKeyValuesLength: 0) + { + OverrideCache = cache, + }; + var output = this.Model.forward(input); + logits = output.Logits; + } + for (var curPos = promptLength; curPos != totalLen; curPos++) { - var prevPos = 0; - var eosReached = torch.tensor(new bool[batch], device: device); - torch.Tensor? logits = default; - var cache = new DynamicKVCache(); - if (promptLength == totalLen) + var input = new CasualLMModelInput(inputIds[.., prevPos..curPos], attentionMask[.., prevPos..curPos], pastKeyValuesLength: prevPos) + { + OverrideCache = cache, + }; + var output = this.Model.forward(input); + logits = output.Logits?.MoveToOtherDisposeScope(inputIds) ?? throw new InvalidOperationException("Logits is null"); + torch.Tensor nextToken; + if (temperature > 0) + { + var probs = torch.softmax(logits[.., -1] / temperature, dim: -1); + nextToken = this.SampleTopP(probs, topP); + } + else + { + nextToken = torch.argmax(logits[.., -1], dim: -1); + } + + nextToken = nextToken.reshape(-1); + inputIds = torch.cat([inputIds, nextToken.unsqueeze(1)], dim: -1).MoveToOtherDisposeScope(inputIds); + attentionMask = torch.cat([attentionMask, attentionMask.new_ones(attentionMask.shape[0], 1)], dim: -1); + foreach (var stopSequence in stopTokenSequence) { - var input = new CasualLMModelInput(inputIds, attentionMask, pastKeyValuesLength: 0) - { - OverrideCache = cache, - }; - var output = this.Model.forward(input); - logits = output.Logits; + // determine if the last n tokens are the stop sequence + var lastN = inputIds[.., ^stopSequence.Length..]; + var lastNMatch = lastN == torch.tensor(stopSequence, device: device); + eosReached |= lastNMatch.all(dim: -1); } - for (var curPos = promptLength; curPos != totalLen; curPos++) + if (eosReached.all().item()) { - var input = new CasualLMModelInput(inputIds[.., prevPos..curPos], attentionMask[.., prevPos..curPos], pastKeyValuesLength: prevPos) - { - OverrideCache = cache, - }; - var output = this.Model.forward(input); - logits = output.Logits ?? throw new InvalidOperationException("Logits is null"); - torch.Tensor nextToken; - if (temperature > 0) - { - var probs = torch.softmax(logits[.., -1] / temperature, dim: -1); - nextToken = this.SampleTopP(probs, topP); - } - else - { - nextToken = torch.argmax(logits[.., -1], dim: -1); - } - - nextToken = nextToken.reshape(-1); - inputIds = torch.cat([inputIds, nextToken.unsqueeze(1)], dim: -1); - attentionMask = torch.cat([attentionMask, attentionMask.new_ones(attentionMask.shape[0], 1)], dim: -1); - foreach (var stopSequence in stopTokenSequence) - { - // determine if the last n tokens are the stop sequence - var lastN = inputIds[.., ^stopSequence.Length..]; - var lastNMatch = lastN == torch.tensor(stopSequence, device: device); - eosReached |= lastNMatch.all(dim: -1); - } - if (eosReached.all().item()) - { - break; - } - - // pBar.Tick(curPos, message); - var nextTokenIds = nextToken.to_type(ScalarType.Int32).data().ToArray(); - var nextTokenStr = this.Tokenizer.Decode(nextTokenIds); - - prevPos = curPos; + break; } - if (echo) + yield return (nextToken.MoveToOuterDisposeScope(), logits[.., ^1].MoveToOuterDisposeScope()); + prevPos = curPos; + } + } + + public virtual ( + Tensor, // output token ids [batch_size, sequence_length] + Tensor // output logits [batch_size, sequence_length, vocab_size] + ) Generate( + Tensor inputIds, // input token ids [batch_size, sequence_length] + Tensor attentionMask, // attention mask [batch_size, sequence_length] + int[][] stopTokenSequence, + float temperature = 0.7f, + float topP = 0.9f, + int maxLen = 128) + { + using var scope = NewDisposeScope(); + Tensor? logits = null; + foreach (var (token, _logits) in this.GenerateStreaming(inputIds, attentionMask, stopTokenSequence, temperature, topP, maxLen)) + { + inputIds = torch.cat([inputIds, token.unsqueeze(1)], dim: -1).MoveToOtherDisposeScope(inputIds); + if (logits is null) { - // return entire inputIds and logits - return (inputIds.MoveToOuterDisposeScope(), logits!.MoveToOuterDisposeScope()); + logits = _logits; } else { - // return [batch_size, promptLength..] and [batch_size, promptLength.., vocab_size] - return (inputIds[.., promptLength..].MoveToOuterDisposeScope(), logits![.., promptLength..].MoveToOuterDisposeScope()); + logits = torch.cat([logits, _logits], dim: -1).MoveToOtherDisposeScope(inputIds); } } + + return (inputIds, logits ?? throw new InvalidOperationException("Logits is null")); } - public virtual string? Generate( + public virtual string Generate( string prompt, int maxLen = 128, float temperature = 0.7f, float topP = 0.9f, - string[]? stopSequences = null, - bool echo = false) + string[]? stopSequences = null) + { + var chunks = new List(); + + foreach (var chunk in this.GenerateStreaming(prompt, maxLen, temperature, topP, stopSequences)) + { + chunks.Add(chunk); + } + + return string.Join(string.Empty, chunks); + } + + + public virtual IEnumerable GenerateStreaming( + string prompt, + int maxLen = 128, + float temperature = 0.7F, + float topP = 0.9F, + string[]? stopSequences = null) { using var newScope = NewDisposeScope(); var inputIds = this.Tokenizer.EncodeToIds(prompt); var inputTensor = torch.tensor(inputIds.ToArray(), dtype: ScalarType.Int64, device: this.Device).unsqueeze(0); var attentionMask = torch.ones_like(inputTensor); - // set up stop token ids // stop token ids: [[eosId], [stopSequence1], [stopSequence2], ...] // when causal language model generates tokens, it will stop when it generates any token in stopSequences @@ -185,11 +240,16 @@ public virtual ( stopTokenIds = stopTokenIds.Where(ids => ids.Count() > 0).ToList(); - (var token, var _) = this.Generate(inputTensor, attentionMask, temperature: temperature, maxLen: maxLen, topP: topP, stopTokenSequence: stopTokenIds.ToArray(), echo: echo); - - var tokenIds = token[0].to_type(ScalarType.Int32).data().ToArray(); + foreach (var (token, _) in this.GenerateStreaming(inputTensor, attentionMask, stopTokenIds.ToArray(), temperature: temperature, maxLen: maxLen)) + { + var tokenIds = token[0].to_type(ScalarType.Int32).data().ToArray(); + var duplicateTokenString = this.Tokenizer.Decode(tokenIds.Concat(tokenIds)) ?? throw new InvalidOperationException("Failed to decode token ids"); + var tokenString = this.Tokenizer.Decode(tokenIds) ?? throw new InvalidOperationException("Failed to decode token ids"); + // replace the first occurrence of the token with the duplicate token + tokenString = duplicateTokenString.Substring(tokenString.Length); - return this.Tokenizer.Decode(tokenIds); + yield return tokenString; + } } protected torch.Tensor SampleTopP(torch.Tensor logits, float topP) diff --git a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj index 19b1c156f2..f4cdc6333a 100644 --- a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj +++ b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj @@ -15,6 +15,7 @@ + diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs index c4ff1a8b23..abe1e92716 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Threading.Tasks; using AutoGen.Core; @@ -13,7 +14,7 @@ namespace Microsoft.ML.GenAI.Phi; -public class Phi3Agent : IAgent +public class Phi3Agent : IStreamingAgent { private const char Newline = '\n'; private readonly ICausalLMPipeline _pipeline; @@ -32,6 +33,46 @@ public Phi3Agent( public string Name { get; } public Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + { + var input = BuildPrompt(messages); + var maxLen = options?.MaxToken ?? 1024; + var temperature = options?.Temperature ?? 0.7f; + var stopTokenSequence = options?.StopSequence ?? []; + stopTokenSequence = stopTokenSequence.Append("<|end|>").ToArray(); + + var output = _pipeline.Generate( + input, + maxLen: maxLen, + temperature: temperature, + stopSequences: stopTokenSequence) ?? throw new InvalidOperationException("Failed to generate a reply."); + + return Task.FromResult(new TextMessage(Role.Assistant, output, from: this.Name)); + } + +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + public async IAsyncEnumerable GenerateStreamingReplyAsync( +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + IEnumerable messages, + GenerateReplyOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var input = BuildPrompt(messages); + var maxLen = options?.MaxToken ?? 1024; + var temperature = options?.Temperature ?? 0.7f; + var stopTokenSequence = options?.StopSequence ?? []; + stopTokenSequence = stopTokenSequence.Append("<|end|>").ToArray(); + + foreach (var output in _pipeline.GenerateStreaming( + input, + maxLen: maxLen, + temperature: temperature, + stopSequences: stopTokenSequence)) + { + yield return new TextMessageUpdate(Role.Assistant, output, from: this.Name); + } + } + + private string BuildPrompt(IEnumerable messages) { var availableRoles = new[] { Role.System, Role.User, Role.Assistant }; if (messages.Any(m => m.GetContent() is null)) @@ -68,17 +109,6 @@ _ when message.GetRole() == Role.Assistant => $"<|assistant|>{Newline}{content}< sb.Append("<|assistant|>"); var input = sb.ToString(); - var maxLen = options?.MaxToken ?? 1024; - var temperature = options?.Temperature ?? 0.7f; - var stopTokenSequence = options?.StopSequence ?? []; - stopTokenSequence = stopTokenSequence.Append("<|end|>").ToArray(); - - var output = _pipeline.Generate( - input, - maxLen: maxLen, - temperature: temperature, - stopSequences: stopTokenSequence) ?? throw new InvalidOperationException("Failed to generate a reply."); - - return Task.FromResult(new TextMessage(Role.Assistant, output, from: this.Name)); + return input; } } diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatCompletionService.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatCompletionService.cs index 4888d7a0f8..efe3089fdb 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatCompletionService.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatCompletionService.cs @@ -31,6 +31,28 @@ public async Task> GetChatMessageContentsAsync PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + { + var prompt = BuildPrompt(chatHistory); + var reply = await _textGenerationService.GetTextContentAsync(prompt, executionSettings, kernel, cancellationToken); + return [new ChatMessageContent(AuthorRole.Assistant, reply.Text)]; + } + + public async IAsyncEnumerable GetStreamingChatMessageContentsAsync( + ChatHistory chatHistory, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + [EnumeratorCancellation] + CancellationToken cancellationToken = default) + { + var prompt = BuildPrompt(chatHistory); + + await foreach (var reply in _textGenerationService.GetStreamingTextContentsAsync(prompt, executionSettings, kernel, cancellationToken)) + { + yield return new StreamingChatMessageContent(AuthorRole.Assistant, reply.Text); + } + } + + private string BuildPrompt(ChatHistory chatHistory) { // build prompt from chat history var sb = new StringBuilder(); @@ -57,25 +79,7 @@ public async Task> GetChatMessageContentsAsync } sb.Append("<|assistant|>"); - var reply = await _textGenerationService.GetTextContentAsync(sb.ToString(), executionSettings, kernel, cancellationToken); - return [new ChatMessageContent(AuthorRole.Assistant, reply.Text)]; - } - - public async IAsyncEnumerable GetStreamingChatMessageContentsAsync( - ChatHistory chatHistory, - PromptExecutionSettings? executionSettings = null, - Kernel? kernel = null, - [EnumeratorCancellation] - CancellationToken cancellationToken = default) - { - // CausalLMPipeline doesn't support streaming output yet - // here we simply implement this api using the synchronous version - var response = await GetChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken); - - foreach (var item in response) - { - yield return new StreamingChatMessageContent(item.Role, item.Content); - } + return sb.ToString(); } } diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMTextGenerationService.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMTextGenerationService.cs index a94c91fe0b..ac22b4f353 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMTextGenerationService.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMTextGenerationService.cs @@ -30,22 +30,23 @@ public Phi3CausalLMTextGenerationService(ICausalLMPipeline> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) { var temperature = executionSettings?.ExtensionData?["temperature"] as float? ?? 0.7f; - var maxToken = executionSettings?.ExtensionData?["max_token"] as int? ?? 100; - var stopTokenSequence = executionSettings?.ExtensionData?["stop_token_sequence"] as string[] ?? Array.Empty(); + var maxToken = executionSettings?.ExtensionData?["max_token"] as int? ?? 512; + var stopTokenSequence = executionSettings?.ExtensionData?["stop_token_sequence"] as List ?? new List(); var topP = executionSettings?.ExtensionData?["top_p"] as float? ?? 0.9f; - stopTokenSequence.Append("<|end|>"); - + stopTokenSequence.Add("<|end|>"); var response = _pipeline.Generate( prompt, maxToken, temperature, - stopSequences: stopTokenSequence, + stopSequences: stopTokenSequence.ToArray(), topP: topP); return Task.FromResult>([new TextContent(response)]); } +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously public async IAsyncEnumerable GetStreamingTextContentsAsync( +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously string prompt, PromptExecutionSettings? executionSettings = null, @@ -53,14 +54,20 @@ public async IAsyncEnumerable GetStreamingTextContentsAsyn [EnumeratorCancellation] CancellationToken cancellationToken = default) { - // CausalLMPipeline doesn't support streaming output yet - // here we simply implement this api using the synchronous version - - var response = await GetTextContentsAsync(prompt, executionSettings, kernel, cancellationToken); + var temperature = executionSettings?.ExtensionData?["temperature"] as float? ?? 0.7f; + var maxToken = executionSettings?.ExtensionData?["max_token"] as int? ?? 100; + var stopTokenSequence = executionSettings?.ExtensionData?["stop_token_sequence"] as string[] ?? Array.Empty(); + var topP = executionSettings?.ExtensionData?["top_p"] as float? ?? 0.9f; + stopTokenSequence.Append("<|end|>"); - foreach (var item in response) + foreach (var item in _pipeline.GenerateStreaming( + prompt, + maxToken, + temperature, + topP, + stopTokenSequence)) { - yield return new StreamingTextContent(item.Text); + yield return new StreamingTextContent(item); } } } diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs index e204def0ec..11116cdb72 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs @@ -10,8 +10,6 @@ public class Phi3Tokenizer : Tokenizer { private readonly SentencePieceBpe _tokenizer; private readonly bool _addPrecedingSpace; - private readonly bool _addBeginningOfSentence; - private readonly bool _addEndOfSentence; private const string SystemSymbol = "<|system|>"; private const string UserSymbol = "<|user|>"; private const string AssistantSymbol = "<|assistant|>"; @@ -29,14 +27,10 @@ public class Phi3Tokenizer : Tokenizer }; public Phi3Tokenizer(string modelPath, - bool addPrecedingSpace = true, - bool addBeginningOfSentence = true, - bool addEndOfSentence = true) + bool addPrecedingSpace = true) { var modelStream = File.OpenRead(modelPath); this._addPrecedingSpace = addPrecedingSpace; - this._addBeginningOfSentence = addBeginningOfSentence; - this._addEndOfSentence = addEndOfSentence; this._tokenizer = (SentencePieceBpe)Tokenizer.CreateLlama(modelStream, false, false); } @@ -83,6 +77,28 @@ public override IReadOnlyList Encode(ReadOnlySpan text, out string? return tokens.ToArray(); } + public IReadOnlyList EncodeToIds( + string text, + bool addBeginningOfSentence, + bool addEndOfSentence, + bool considerPreTokenization = true, + bool considerNormalization = true) + { + var ids = this.EncodeToIds(text, considerPreTokenization: considerPreTokenization, considerNormalization: considerNormalization); + + if (addBeginningOfSentence) + { + ids = new int[] { this.BosId }.Concat(ids).ToArray(); + } + + if (addEndOfSentence) + { + ids = ids.Concat(new int[] { this.EosId }).ToArray(); + } + + return ids; + } + public override IReadOnlyList EncodeToIds(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true) { var input = text.ToString(); @@ -99,14 +115,14 @@ public override IReadOnlyList EncodeToIds(ReadOnlySpan text, bool con var subString = input.Substring(0, index); var subTokens = this._tokenizer.EncodeToIds(subString, addBeginningOfSentence: false, addEndOfSentence: false, considerPreTokenization: false, considerNormalization: true).ToArray(); // remove the first sub Token as it will always be '_' - tokens.AddRange(subTokens.Skip(1)); + tokens.AddRange(subTokens); tokens.Add(this._specialTokenMap[specialToken]); input = input.Remove(0, index + specialToken.Length); } tokens.AddRange(this._tokenizer.EncodeToIds(input, addBeginningOfSentence: false, addEndOfSentence: false, considerPreTokenization: false, considerNormalization: true).ToArray()); - return this._addBeginningOfSentence ? new int[] { this.BosId }.Concat(tokens).ToArray() : tokens.ToArray(); + return tokens.ToArray(); } public override IReadOnlyList EncodeToIds(ReadOnlySpan text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) @@ -159,36 +175,6 @@ public override int LastIndexOfTokenCount(ReadOnlySpan text, int maxTokenC var str = this._tokenizer.Decode(replacedIds) ?? throw new Exception("Failed to decode ids"); return str; - - //var tokens = new List(); - //foreach (var id in ids) - //{ - // if (_specialTokenMap.ContainsValue(id)) - // { - // tokens.Add(_specialTokenMap.First(x => x.Value == id).Key); - // } - // else - // { - // tokens.Add(this._tokenizer.MapIdToToken(id) ?? throw new Exception("Failed to map id to token")); - // } - //} - - //if (this._addBeginningOfSentence) - //{ - // tokens = tokens[1..].ToList(); - //} - - //var str = string.Join("", tokens); - - //// replace Dummy with whitespace - //str = str.Replace(SentencePieceNormalizer.DummyPrefix, ' '); - - //if (this._addPrecedingSpace) - //{ - // str = str.TrimStart(' '); - //} - - //return str; } public override int? MapTokenToId(ReadOnlySpan token) diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.TokenizerTest.approved.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.TokenizerTest.approved.txt index 442bde7bca..95b1fb630d 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.TokenizerTest.approved.txt +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.TokenizerTest.approved.txt @@ -2,4 +2,7 @@ 1, 18585, 29991, 2266, 526, 777, 5837, 304, 17545, 9892, 16397, 322, 8338, 265, 29888, 21211, 4208, 29901, 29871, 29896, 29889, 10765, 1648, 322, 8338, 265, 29888, 9216, 10597, 347, 29901, 3164, 355, 9892, 16397, 322, 8338, 265, 29888, 21211, 4208, 411, 777, 27274, 322, 298, 4992, 29889, 29871, 29906, 29889, 10765, 1648, 322, 8338, 265, 29888, 9216, 4497, 328, 29901, 23478, 269, 506, 287, 9892, 16397, 322, 8338, 265, 29888, 21211, 4208, 411, 777, 454, 3712, 3623, 625, 322, 298, 4992, 29889 1, 1724, 1048, 17069, 385, 29871, 29906, 29916, 718, 29871, 29941, 353, 29871, 29955, 6306, 29973 1, 29871, 13, 3981, 304, 29871, 29941, 13 +1, 32010 +1, 32007 +1, 32001 1, 32010, 29871, 13, 3981, 304, 29871, 29941, 32007, 29871, 13, 32001 diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/AutoGenTests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/AutoGenTests.cs index 21ae811fda..e08e496eff 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/AutoGenTests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/AutoGenTests.cs @@ -30,14 +30,13 @@ public async Task ItGenerateTextReply() It.IsAny(), // max length It.IsAny(), // temperature It.IsAny(), // top_p - It.IsAny(), // stop sequence - It.IsAny())) // echo - .Callback((string prompt, int maxLen, float temperature, float topP, string[] stopSequences, bool echo) => + It.IsAny())) // stop sequence + .Callback((string prompt, int maxLen, float temperature, float topP, string[] stopSequences) => { // check prompt prompt.Should().Be("<|system|>\nyou are a helpful assistant<|end|>\n<|user|>\nhey<|end|>\n<|assistant|>"); }) - .Returns((string prompt, int maxLen, float temperature, float topP, string[] stopSequences, bool echo) => "hello"); + .Returns((string prompt, int maxLen, float temperature, float topP, string[] stopSequences) => "hello"); var agent = new Phi3Agent(pipeline, "assistant"); var reply = await agent.SendAsync("hey"); diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs index d83777a2f7..f2fe90a9ca 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs @@ -90,7 +90,7 @@ public void TokenizerTest() tokenizer.EosId.Should().Be(2); // test <|end|> - var endIds = tokenizer.EncodeToIds("<|end|>", considerPreTokenization: false, considerNormalization: false); + var endIds = tokenizer.EncodeToIds("<|end|>", addBeginningOfSentence: false, addEndOfSentence: false, considerPreTokenization: false, considerNormalization: false); endIds.Should().BeEquivalentTo(new int[] { 32007 }); var messages = new string[] @@ -107,10 +107,8 @@ public void TokenizerTest() var sb = new StringBuilder(); foreach (var message in messages) { - var tokenized = tokenizer.EncodeToIds(message, considerPreTokenization: true); - var decodedString = tokenizer.Decode(tokenized); - decodedString.Should().Be(message); - var tokenizedStr = string.Join(", ", tokenized.Select(x => x.ToString())); + var tokenizeIds = tokenizer.EncodeToIds(message, true, false, considerPreTokenization: true); + var tokenizedStr = string.Join(", ", tokenizeIds.Select(x => x.ToString())); sb.AppendLine(tokenizedStr); } diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/SemanticKernelTests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/SemanticKernelTests.cs index ec61de4601..09c11c8886 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/SemanticKernelTests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/SemanticKernelTests.cs @@ -26,21 +26,20 @@ public SemanticKernelTests(ITestOutputHelper helper) [Fact] public async Task ItAddPhi3CausalLMChatCompletionServiceTestAsync() { - var pipeline = Mock.Of>(MockBehavior.Loose); + var pipeline = Mock.Of>(); // mock generate api Mock.Get(pipeline).Setup(p => p.Generate( It.IsAny(), // prompt It.IsAny(), // max length It.IsAny(), // temperature It.IsAny(), // top_p - It.IsAny(), // stop sequence - It.IsAny())) // echo - .Callback((string prompt, int maxLen, float temperature, float topP, string[] stopSequences, bool echo) => + It.IsAny())) // stop sequence + .Callback((string prompt, int maxLen, float temperature, float topP, string[] stopSequences) => { // check prompt prompt.Should().Be("<|system|>\nyou are a helpful assistant<|end|>\n<|user|>\nhey<|end|>\n<|assistant|>"); }) - .Returns((string prompt, int maxLen, float temperature, float topP, string[] stopSequences, bool echo) => "hello"); + .Returns((string prompt, int maxLen, float temperature, float topP, string[] stopSequences) => "hello"); var kernel = Kernel.CreateBuilder() .AddPhi3AsChatCompletion(pipeline) @@ -63,21 +62,20 @@ public async Task ItAddPhi3CausalLMChatCompletionServiceTestAsync() [Fact] public async Task ItAddPhi3CausalLMTextGenerationServiceTestAsync() { - var pipeline = Mock.Of>(MockBehavior.Loose); + var pipeline = Mock.Of>(); // mock generate api Mock.Get(pipeline).Setup(p => p.Generate( It.IsAny(), // prompt It.IsAny(), // max length It.IsAny(), // temperature It.IsAny(), // top_p - It.IsAny(), // stop sequence - It.IsAny())) // echo - .Callback((string prompt, int maxLen, float temperature, float topP, string[] stopSequences, bool echo) => + It.IsAny())) // stop sequence + .Callback((string prompt, int maxLen, float temperature, float topP, string[] stopSequences) => { // check prompt prompt.Should().Be("test"); }) - .Returns((string prompt, int maxLen, float temperature, float topP, string[] stopSequences, bool echo) => "hello"); + .Returns((string prompt, int maxLen, float temperature, float topP, string[] stopSequences) => "hello"); var kernel = Kernel.CreateBuilder() .AddPhi3AsTextGeneration(pipeline) From 3db5c611abed18e5f0ce134992ca6eb60af6245e Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 28 Jun 2024 11:16:13 -0700 Subject: [PATCH 09/41] use tokenzier model from testTokenizer package --- docs/samples/Microsoft.ML.GenAI.Samples/Program.cs | 2 +- .../Microsoft.ML.GenAI.Phi.Tests.csproj | 1 + test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs | 2 +- test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs index 63e18e5f42..5e4355e595 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs @@ -1,4 +1,4 @@ // See https://aka.ms/new-console-template for more information using Microsoft.ML.GenAI.Samples.Phi3Mini; -await SemanticKernelSample.RunTextGenerationSample(); +await AutoGenSample.RunAsync(); diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj index 5d1d22f4f0..0fbe73d1d0 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj @@ -18,6 +18,7 @@ + diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs index 44d1d74d7e..3d99cd986d 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs @@ -44,7 +44,7 @@ public void LoadSafeTensorShapeTest() [UseApprovalSubdirectory("Approvals")] public void TokenizerTest() { - var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\phi-2"; + var modelWeightFolder = Path.Join("Phi-2"); var tokenizer = Tokenizer.CreatePhi2(modelWeightFolder, addBeginOfSentence: true); tokenizer.EndOfSentenceId.Should().Be(50256); tokenizer.BeginningOfSentenceId.Should().Be(50256); diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs index f2fe90a9ca..c427114d95 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs @@ -84,7 +84,7 @@ public void Phi3Mini128KLayerSizeTest() [UseApprovalSubdirectory("Approvals")] public void TokenizerTest() { - var modelWeightFolder = "C:\\Users\\xiaoyuz\\source\\repos\\Phi-3-mini-4k-instruct"; + var modelWeightFolder = Path.Join("Llama"); var tokenizer = Phi3Tokenizer.FromPretrained(modelWeightFolder); tokenizer.BosId.Should().Be(1); tokenizer.EosId.Should().Be(2); From d72fbcea216b35558720037bb0d683d6c5507d2d Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 28 Jun 2024 11:38:24 -0700 Subject: [PATCH 10/41] use defaults --- .../Pipeline/CasualLMModelInput.cs | 24 ++++++--- .../Pipeline/CasualLMModelOutput.cs | 15 ++++-- .../Pipeline/CausalLMPipeline.cs | 49 +++++++++++-------- 3 files changed, 57 insertions(+), 31 deletions(-) diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelInput.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelInput.cs index 31b7530b88..49fcfef627 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelInput.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelInput.cs @@ -8,15 +8,25 @@ namespace Microsoft.ML.GenAI.Core; public class CasualLMModelInput { + internal static class Defaults + { + internal const Tensor? AttentionMask = null; + internal const Tensor? PositionIds = null; + internal const int PastKeyValuesLength = 0; + internal const Tensor? InputsEmbeds = null; + internal const bool UseCache = false; + internal const bool OutputAttentions = false; + internal const bool OutputHiddenStates = false; + } public CasualLMModelInput( Tensor inputIds, - Tensor? attentionMask = null, - Tensor? positionIds = null, - int pastKeyValuesLength = 0, - Tensor? inputsEmbeds = null, - bool useCache = false, - bool outputAttentions = false, - bool outputHiddenStates = false) + Tensor? attentionMask = Defaults.AttentionMask, + Tensor? positionIds = Defaults.PositionIds, + int pastKeyValuesLength = Defaults.PastKeyValuesLength, + Tensor? inputsEmbeds = Defaults.InputsEmbeds, + bool useCache = Defaults.UseCache, + bool outputAttentions = Defaults.OutputAttentions, + bool outputHiddenStates = Defaults.OutputHiddenStates) { this.InputIds = inputIds; this.AttentionMask = attentionMask; diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelOutput.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelOutput.cs index f3ab2c5041..afaa84e778 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelOutput.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelOutput.cs @@ -8,12 +8,19 @@ namespace Microsoft.ML.GenAI.Core; public class CasualLMModelOutput { + internal static class Defaults + { + internal const Tensor? Logits = null; + internal const Tensor[]? AllHiddenStates = null; + internal const Tensor[]? Attentions = null; + internal const IKVCache? Cache = null; + } public CasualLMModelOutput( Tensor lastHiddenState, - Tensor? logits = null, - Tensor[]? allHiddenStates = null, - Tensor[]? attentions = null, - IKVCache? cache = null) + Tensor? logits = Defaults.Logits, + Tensor[]? allHiddenStates = Defaults.AllHiddenStates, + Tensor[]? attentions = Defaults.Attentions, + IKVCache? cache = Defaults.Cache) { this.LastHiddenState = lastHiddenState; this.AllHiddenStates = allHiddenStates; diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs index 417d1d87a2..abb08effa8 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs @@ -27,33 +27,33 @@ public interface ICausalLMPipeline { string Generate( string prompt, - int maxLen = 128, - float temperature = 0.7F, - float topP = 0.9F, - string[]? stopSequences = null); + int maxLen, + float temperature, + float topP, + string[]? stopSequences); IEnumerable GenerateStreaming( string prompt, - int maxLen = 128, - float temperature = 0.7F, - float topP = 0.9F, - string[]? stopSequences = null); + int maxLen, + float temperature, + float topP, + string[]? stopSequences); (Tensor, Tensor) Generate( Tensor inputIds, Tensor attentionMask, int[][] stopTokenSequence, - float temperature = 0.7F, - float topP = 0.9F, - int maxLen = 128); + float temperature, + float topP, + int maxLen); IEnumerable<(Tensor, Tensor)> GenerateStreaming( Tensor inputIds, Tensor attentionMask, int[][] stopTokenSequence, - float temperature = 0.7F, - float topP = 0.9F, - int maxLen = 128); + float temperature, + float topP, + int maxLen); } public class CausalLMPipeline : CausalLMPipeline, ICausalLMPipeline @@ -63,7 +63,7 @@ public class CausalLMPipeline : CausalLMPipeline, ICausalLMP public CausalLMPipeline( TTokenizer tokenizer, TModel model, - string device = "cpu") + string device = Defaults.Device) : base(tokenizer, model, device) { } @@ -75,10 +75,19 @@ public CausalLMPipeline( public class CausalLMPipeline : ICausalLMPipeline { + internal static class Defaults + { + internal const string Device = "cpu"; + internal const float Temperature = 0.7F; + internal const float TopP = 0.9F; + internal const int MaxLen = 128; + internal const string[]? StopSequence = null; + } + public CausalLMPipeline( Tokenizer tokenizer, nn.Module model, - string device = "cpu") + string device = Defaults.Device) { this.Tokenizer = tokenizer; this.Model = model; @@ -108,9 +117,9 @@ protected private CausalLMPipeline() Tensor inputIds, Tensor attentionMask, int[][] stopTokenSequence, - float temperature = 0.7F, - float topP = 0.9F, - int maxLen = 128) + float temperature = Defaults.Temperature, + float topP = Defaults.TopP, + int maxLen = Defaults.MaxLen) { using var scope = NewDisposeScope(); using var noGrad = torch.no_grad(); @@ -223,7 +232,7 @@ public virtual IEnumerable GenerateStreaming( int maxLen = 128, float temperature = 0.7F, float topP = 0.9F, - string[]? stopSequences = null) + string[]? stopSequences = Defaults.StopSequence) { using var newScope = NewDisposeScope(); var inputIds = this.Tokenizer.EncodeToIds(prompt); From 9cc9a0c2165515e498b2f7d2bca4540f41c2c994 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 28 Jun 2024 13:14:55 -0700 Subject: [PATCH 11/41] add quantize linear --- Microsoft.ML.sln | 13 +- .../Phi3Mini/AutoGenSample.cs | 2 +- .../Phi3Mini/Utils.cs | 14 +- .../Extension/ModuleExtension.cs | 47 +- .../Microsoft.ML.GenAI.Core.csproj | 1 + .../Module/GenAILinear.cs | 14 +- .../Module/IQuantizeModule.cs | 4 +- .../Module/QuantizedLinear.cs | 208 ++++++++ .../Pipeline/CausalLMPipeline.cs | 28 +- .../Module/Phi3Attention.cs | 8 +- src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs | 8 +- .../Microsoft.ML.GenAI.Core.Tests.csproj | 31 ++ .../QuantizedLinearTests.cs | 134 ++++++ ...i3Mini4KInt4QuantizeShapeTest.approved.txt | 451 ++++++++++++++++++ ...i3Mini4KInt8QuantizeShapeTest.approved.txt | 451 ++++++++++++++++++ ...i3Mini4KInt8QuantizeShapeTest.received.txt | 451 ++++++++++++++++++ .../Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs | 28 ++ 17 files changed, 1857 insertions(+), 36 deletions(-) create mode 100644 src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs create mode 100644 test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj create mode 100644 test/Microsoft.ML.GenAI.Core.Tests/QuantizedLinearTests.cs create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KInt4QuantizeShapeTest.approved.txt create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KInt8QuantizeShapeTest.approved.txt create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KInt8QuantizeShapeTest.received.txt diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index c30f67410a..d3985d1777 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -182,7 +182,9 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.Phi", "s EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.Phi.Tests", "test\Microsoft.ML.GenAI.Phi.Tests\Microsoft.ML.GenAI.Phi.Tests.csproj", "{867FFC34-DFA7-400F-B9BB-85158326CE08}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.Samples", "docs\samples\Microsoft.ML.GenAI.Samples\Microsoft.ML.GenAI.Samples.csproj", "{1D4AD9A3-19AF-432B-889D-A63FE6D7BD47}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.Samples", "docs\samples\Microsoft.ML.GenAI.Samples\Microsoft.ML.GenAI.Samples.csproj", "{1D4AD9A3-19AF-432B-889D-A63FE6D7BD47}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.Core.Tests", "test\Microsoft.ML.GenAI.Core.Tests\Microsoft.ML.GenAI.Core.Tests.csproj", "{14AB0804-D4CE-4634-B544-5A8587620783}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -868,6 +870,14 @@ Global {1D4AD9A3-19AF-432B-889D-A63FE6D7BD47}.Release|Any CPU.Build.0 = Release|Any CPU {1D4AD9A3-19AF-432B-889D-A63FE6D7BD47}.Release|x64.ActiveCfg = Release|Any CPU {1D4AD9A3-19AF-432B-889D-A63FE6D7BD47}.Release|x64.Build.0 = Release|Any CPU + {14AB0804-D4CE-4634-B544-5A8587620783}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {14AB0804-D4CE-4634-B544-5A8587620783}.Debug|Any CPU.Build.0 = Debug|Any CPU + {14AB0804-D4CE-4634-B544-5A8587620783}.Debug|x64.ActiveCfg = Debug|Any CPU + {14AB0804-D4CE-4634-B544-5A8587620783}.Debug|x64.Build.0 = Debug|Any CPU + {14AB0804-D4CE-4634-B544-5A8587620783}.Release|Any CPU.ActiveCfg = Release|Any CPU + {14AB0804-D4CE-4634-B544-5A8587620783}.Release|Any CPU.Build.0 = Release|Any CPU + {14AB0804-D4CE-4634-B544-5A8587620783}.Release|x64.ActiveCfg = Release|Any CPU + {14AB0804-D4CE-4634-B544-5A8587620783}.Release|x64.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -958,6 +968,7 @@ Global {694BF884-B2E4-4E1C-9342-0564BAAC4575} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {867FFC34-DFA7-400F-B9BB-85158326CE08} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {1D4AD9A3-19AF-432B-889D-A63FE6D7BD47} = {DA452A53-2E94-4433-B08C-041EDEC729E6} + {14AB0804-D4CE-4634-B544-5A8587620783} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs index be26cc035e..17ce52cb10 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs @@ -26,7 +26,7 @@ public static async Task RunAsync() torch.manual_seed(1); torch.set_default_dtype(defaultType); var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"; - var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device); + var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device, quantizeToInt4: true); // agent var agent = new Phi3Agent(pipeline, "assistant") diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs index 33769d6330..4a4c108749 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs @@ -19,13 +19,25 @@ public static CausalLMPipeline LoadPhi3Mini4KFro string device = "cuda", int modelSizeOnCudaInGB = 16, int modelSizeOnMemoryInGB = 64, - int modelSizeOnDiskInGB = 200) + int modelSizeOnDiskInGB = 200, + bool quantizeToInt8 = false, + bool quantizeToInt4 = false) { var defaultType = ScalarType.Float16; Console.WriteLine("Loading Phi3 from huggingface model weight folder"); var timer = System.Diagnostics.Stopwatch.StartNew(); var model = Phi3ForCasualLM.FromPretrained(weightFolder, device: device, torchDtype: defaultType, checkPointName: "model.safetensors.index.json"); var tokenizer = Phi3Tokenizer.FromPretrained(weightFolder); + + if (quantizeToInt8) + { + model.ToInt8QuantizeModule(); + } + else if (quantizeToInt4) + { + model.ToInt4QuantizeModule(); + } + var deviceSizeMap = new Dictionary { ["cuda:0"] = modelSizeOnCudaInGB * 1024 * 1024 * 1024, diff --git a/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs b/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs index a3fd98b1f3..18633728a5 100644 --- a/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs +++ b/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs @@ -56,19 +56,60 @@ public static Dictionary GetSizeForEachDynamicLayerInBytes(this nn } } - public static void ToQuantizedModule( + /// + /// Quantize the module using zero-point int8 quantization. + /// + /// + /// + public static void ToInt8QuantizeModule( + this T model) + where T : nn.Module + { + if (model is IQuantizeModule quantized) + { + quantized.Int8(); + + return; + } + + foreach (var (_, value) in model.named_children()) + { + if (value is IQuantizeModule quantizeModule) + { + quantizeModule.Int8(); + } + else + { + value.ToInt8QuantizeModule(); + } + } + } + + /// + /// Quantize the module using zero-point int4 quantization. + /// + /// + /// + public static void ToInt4QuantizeModule( this T model) where T : nn.Module { + if (model is IQuantizeModule quantized) + { + quantized.Int4(); + + return; + } + foreach (var (_, value) in model.named_children()) { if (value is IQuantizeModule quantizeModule) { - quantizeModule.Quantize(); + quantizeModule.Int4(); } else { - value.ToQuantizedModule(); + value.ToInt4QuantizeModule(); } } } diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj index 17d1b45dbb..9f358d9914 100644 --- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj +++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj @@ -19,6 +19,7 @@ + diff --git a/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs b/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs index d206c0dfa2..c59fd9f38d 100644 --- a/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs +++ b/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs @@ -8,12 +8,12 @@ namespace Microsoft.ML.GenAI; internal class GenAILinear : nn.Module { -#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format - private readonly Tensor weight; - private readonly Tensor? bias; -#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format - private readonly int _inFeatures; - private readonly int _outFeatures; +#pragma warning disable MSML_GeneralName // This name should be PascalCased + protected Tensor? weight; + protected Tensor? bias; + protected readonly int _inFeatures; + protected readonly int _outFeatures; +#pragma warning restore MSML_GeneralName // This name should be PascalCased public GenAILinear(int inFeatures, int outFeatures, bool hasBias = true, ScalarType dtype = ScalarType.Float32, string? device = null) : base(nameof(GenAILinear)) @@ -39,7 +39,7 @@ public override Tensor forward(Tensor input) // use float32 var input2 = input.to_type(ScalarType.Float32); - var weight2 = this.weight.to_type(ScalarType.Float32); + var weight2 = this.weight!.to_type(ScalarType.Float32); var result = torch.matmul(input2, weight2.t()); if (this.bias is not null) diff --git a/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs b/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs index 164936f3d7..57c0b7620f 100644 --- a/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs +++ b/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs @@ -6,5 +6,7 @@ namespace Microsoft.ML.GenAI.Core; public interface IQuantizeModule { - public void Quantize(); + public void Int8(); + + public void Int4(); } diff --git a/src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs b/src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs new file mode 100644 index 0000000000..268ac0a4a4 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs @@ -0,0 +1,208 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. +using System; +using Microsoft.ML.GenAI.Core; +using TorchSharp; +using static TorchSharp.torch; +namespace Microsoft.ML.GenAI; + +internal class QuantizedLinear : GenAILinear, IQuantizeModule +{ + public QuantizedLinear(int inFeatures, int outFeatures, bool hasBias = true, ScalarType dtype = ScalarType.Float32, string? device = null) + : base(inFeatures, outFeatures, hasBias, dtype, device) + { + } + + public void Int8() + { + if (this.weight is null) + { + throw new Exception("Weight is not initialized"); + } + + if (this.weight.device_type != DeviceType.META) + { + // if weight is not on meta device, this means that weight and bias are already loaded + // so we can quantize them in memory + + var timer = new System.Diagnostics.Stopwatch(); + timer.Start(); + // scale and zero point on vector-wise + // scale = 255 / max(weight, axis=1) - min(weight, axis=1) + var scale = 255 / (torch.max(this.weight, 1).values - torch.min(this.weight, 1).values); + + // zero point = - scale * min(weight, axis=1) - 128 + var zeroPoint = -scale * torch.min(this.weight, 1).values - 128; + // round zero point to nearest integer + zeroPoint = torch.round(zeroPoint).to(torch.int8); + + // assert zero point is in range [-128, 127] + //if (torch.any(this.zeroPoint < -128).item() || torch.any(this.zeroPoint > 127).item()) + //{ + // throw new Exception("Zero point is out of range [-128, 127]"); + //} + + // quantize weight + var eightBitWeight = torch.round(this.weight * scale.view(-1, 1) + zeroPoint.view(-1, 1)).to(torch.int8); + + // assert weight is in range [-128, 127] + //if (torch.any(this._8bitWeight < -128).item() || torch.any(this._8bitWeight > 127).item()) + //{ + // throw new Exception("Weight is out of range [-128, 127]"); + //} + timer.Stop(); + // dispose float32 weight + this.weight.Dispose(); + this.weight = null; + this._internal_buffers.Remove("weight"); + this.register_buffer("8bit_weight", eightBitWeight); + this.register_buffer("zeroPoint", zeroPoint); + this.register_buffer("scale", scale); + } + else + { + // if weight is on meta device, then we just need to create the placeholder for 8bit_weight, zeroPoint and scale + var eightBitWeight = torch.zeros(this.weight.shape, dtype: torch.int8); + var zeroPoint = torch.zeros(this.weight.shape[0], dtype: torch.int8); + var scale = torch.zeros(this.weight.shape[0], dtype: torch.float32); + + this._internal_buffers.Remove("weight"); + this.weight = null; + this.register_buffer("8bit_weight", eightBitWeight); + this.register_buffer("zeroPoint", zeroPoint); + this.register_buffer("scale", scale); + } + } +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override Tensor forward(Tensor input) +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + if (this._internal_buffers.ContainsKey("weight")) + { + return base.forward(input); + } + else if (this._internal_buffers.ContainsKey("8bit_weight")) + { + // 8bit quantization + using var dispose = torch.NewDisposeScope(); + var weight = this.get_buffer("8bit_weight").to(ScalarType.Float32); + var zeroPoint = this.get_buffer("zeroPoint").to(ScalarType.Float32); + var scale = this.get_buffer("scale").to(ScalarType.Float32); + var restoreWeight = (weight - zeroPoint.view(-1, 1)) / scale.view(-1, 1); + // use float32 + var result = torch.matmul(input.to(ScalarType.Float32), restoreWeight.T); + + if (this.bias is not null) + { + result = result + this.bias.to_type(ScalarType.Float32); + } + + //result.Peek("result"); + return result.to_type(input.dtype).MoveToOuterDisposeScope(); + } + else if (this._internal_buffers.ContainsKey("4bit_weight")) + { + using var dispose = torch.NewDisposeScope(); + var weight = this.get_buffer("4bit_weight"); + var weightLower = weight % 16; + var weightUpper = weight / 16; + weight = torch.cat([weightUpper, weightLower], 0).to(ScalarType.Float32); + weight = weight.view(this._outFeatures, this._inFeatures); + weight -= 8; + var zeroPoint = this.get_buffer("zeroPoint"); + var zeroPointLower = zeroPoint % 16; + var zeroPointUpper = zeroPoint / 16; + zeroPoint = torch.cat([zeroPointUpper, zeroPointLower], 0).to(ScalarType.Float32); + zeroPoint -= 8; + var scale = this.get_buffer("scale").to(ScalarType.Float32); + var restoreWeight = (weight - zeroPoint.view(-1, 1)) / scale.view(-1, 1); + // use float32 + var result = torch.matmul(input.to(ScalarType.Float32), restoreWeight.T); + + if (this.bias is not null) + { + result = result + this.bias.to_type(ScalarType.Float32); + } + + //result.Peek("result"); + return result.to_type(input.dtype).MoveToOuterDisposeScope(); + } + else + { + throw new Exception("Quantization is not done yet"); + } + } + + public void Int4() + { + if (this.weight is null) + { + throw new Exception("Weight is not initialized"); + } + var placeHolderDim = this._outFeatures / 2 + this._outFeatures % 2; + var fourBitWeightDim = this.weight.size(0) * this.weight.size(1); + var fourBitWeightPlaceHolderDim = Convert.ToInt32(fourBitWeightDim / 2 + fourBitWeightDim % 2); + if (this.weight.device_type != DeviceType.META) + { + using var scope = NewDisposeScope(); + var timer = new System.Diagnostics.Stopwatch(); + timer.Start(); + // scale and zero point on vector-wise + // scale = 15 / max(weight, axis=1) - min(weight, axis=1) + var scale = 15 / (torch.max(this.weight, 1).values - torch.min(this.weight, 1).values); + + // zero point = - scale * min(weight, axis=1) - 8 + var zeroPoint = -scale * torch.min(this.weight, 1).values - 8; + // round zero point to nearest integer + zeroPoint = torch.round(zeroPoint); + var fourBitWeight = torch.round(this.weight * scale.view(-1, 1) + zeroPoint.view(-1, 1)).to(torch.int8); + + zeroPoint = (zeroPoint + 8).to(torch.uint8); + fourBitWeight = (fourBitWeight + 8).view(-1).to(torch.uint8); + + // torch doesn't provide int4, so we use int8 as placeholder + // and foreach int8, we save two int4, e.g. 0b1010 -> 0b10, 0b10 + var zpPlaceHolder = zeroPoint[..placeHolderDim]; + zpPlaceHolder = zpPlaceHolder * 16 + zeroPoint[placeHolderDim..]; + + // assert zero point is in range [-128, 127] + //if (torch.any(this.zeroPoint < -128).item() || torch.any(this.zeroPoint > 127).item()) + //{ + // throw new Exception("Zero point is out of range [-128, 127]"); + //} + + // quantize weight + var fourBitWeightPlaceHolder = fourBitWeight[..fourBitWeightPlaceHolderDim]; + fourBitWeightPlaceHolder = fourBitWeightPlaceHolder * 16 + fourBitWeight[fourBitWeightPlaceHolderDim..]; + + // assert weight is in range [-128, 127] + //if (torch.any(this._8bitWeight < -128).item() || torch.any(this._8bitWeight > 127).item()) + //{ + // throw new Exception("Weight is out of range [-128, 127]"); + //} + + // dispose float32 weight + this.weight.Dispose(); + + this._internal_buffers.Remove("weight"); + this.register_buffer("4bit_weight", fourBitWeightPlaceHolder.MoveToOuterDisposeScope()); + this.register_buffer("zeroPoint", zpPlaceHolder.MoveToOuterDisposeScope()); + this.register_buffer("scale", scale.MoveToOuterDisposeScope()); + timer.Stop(); + } + else + { + // if weight is on meta device, then we just need to create the placeholder for 8bit_weight, zeroPoint and scale + var fourBitWeight = torch.zeros(fourBitWeightPlaceHolderDim, dtype: torch.int8); + var zeroPoint = torch.zeros(placeHolderDim, dtype: torch.int8); + var scale = torch.zeros(this.weight.shape[0], dtype: torch.float32); + + this._internal_buffers.Remove("weight"); + this.weight = null; + this.register_buffer("4bit_weight", fourBitWeight); + this.register_buffer("zeroPoint", zeroPoint); + this.register_buffer("scale", scale); + } + } +} diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs index abb08effa8..7122878a9b 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs @@ -27,33 +27,33 @@ public interface ICausalLMPipeline { string Generate( string prompt, - int maxLen, - float temperature, - float topP, - string[]? stopSequences); + int maxLen = CausalLMPipeline.Defaults.MaxLen, + float temperature = CausalLMPipeline.Defaults.Temperature, + float topP = CausalLMPipeline.Defaults.TopP, + string[]? stopSequences = CausalLMPipeline.Defaults.StopSequence); IEnumerable GenerateStreaming( string prompt, - int maxLen, - float temperature, - float topP, - string[]? stopSequences); + int maxLen = CausalLMPipeline.Defaults.MaxLen, + float temperature = CausalLMPipeline.Defaults.Temperature, + float topP = CausalLMPipeline.Defaults.TopP, + string[]? stopSequences = CausalLMPipeline.Defaults.StopSequence); (Tensor, Tensor) Generate( Tensor inputIds, Tensor attentionMask, int[][] stopTokenSequence, - float temperature, - float topP, - int maxLen); + float temperature = CausalLMPipeline.Defaults.Temperature, + float topP = CausalLMPipeline.Defaults.TopP, + int maxLen = CausalLMPipeline.Defaults.MaxLen); IEnumerable<(Tensor, Tensor)> GenerateStreaming( Tensor inputIds, Tensor attentionMask, int[][] stopTokenSequence, - float temperature, - float topP, - int maxLen); + float temperature = CausalLMPipeline.Defaults.Temperature, + float topP = CausalLMPipeline.Defaults.TopP, + int maxLen = CausalLMPipeline.Defaults.MaxLen); } public class CausalLMPipeline : CausalLMPipeline, ICausalLMPipeline diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3Attention.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3Attention.cs index c51b0eef0b..72c7c8946a 100644 --- a/src/Microsoft.ML.GenAI.Phi/Module/Phi3Attention.cs +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3Attention.cs @@ -75,8 +75,8 @@ internal class Phi3Attention : nn.Module? _ropeScaling; #pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format - private readonly GenAILinear o_proj; - private readonly GenAILinear qkv_proj; + private readonly QuantizedLinear o_proj; + private readonly QuantizedLinear qkv_proj; private nn.Module rotary_emb = null!; #pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format @@ -99,8 +99,8 @@ public Phi3Attention(Phi3Config config, int layerIdx) Contract.Assert(this._hiddenSize % (this._headDim * this._numHeads) == 0, "hidden_size must be divisible by num_heads"); var opSize = this._numHeads * this._headDim + 2 * (this._numKeyValueHeads * this._headDim); - this.o_proj = new GenAILinear(this._numHeads * this._headDim, this._hiddenSize, hasBias: false, dtype: config.DType); - this.qkv_proj = new GenAILinear(this._hiddenSize, opSize, hasBias: false, dtype: config.DType); + this.o_proj = new QuantizedLinear(this._numHeads * this._headDim, this._hiddenSize, hasBias: false, dtype: config.DType); + this.qkv_proj = new QuantizedLinear(this._hiddenSize, opSize, hasBias: false, dtype: config.DType); this.InitRope(); } diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs index 752ea9dd2b..745c000800 100644 --- a/src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs @@ -17,8 +17,8 @@ internal class Phi3MLP : torch.nn.Module #pragma warning restore MSML_GeneralName // This name should be PascalCased { #pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format - private readonly GenAILinear gate_up_proj; - private readonly GenAILinear down_proj; + private readonly QuantizedLinear gate_up_proj; + private readonly QuantizedLinear down_proj; private readonly torch.nn.Module activation_fn; #pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format @@ -30,8 +30,8 @@ public Phi3MLP(Phi3Config config) public Phi3MLP(int hiddenSize, int intermediateSize, string hiddenAct, ScalarType dtype) : base(nameof(Phi3MLP)) { - this.gate_up_proj = new GenAILinear(hiddenSize, 2 * intermediateSize, hasBias: false, dtype: dtype); - this.down_proj = new GenAILinear(intermediateSize, hiddenSize, hasBias: false, dtype: dtype); + this.gate_up_proj = new QuantizedLinear(hiddenSize, 2 * intermediateSize, hasBias: false, dtype: dtype); + this.down_proj = new QuantizedLinear(intermediateSize, hiddenSize, hasBias: false, dtype: dtype); this.RegisterComponents(); this.activation_fn = Utils.GetActivation(hiddenAct); } diff --git a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj new file mode 100644 index 0000000000..7e3d9a9943 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj @@ -0,0 +1,31 @@ + + + + net8.0 + enable + enable + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/Microsoft.ML.GenAI.Core.Tests/QuantizedLinearTests.cs b/test/Microsoft.ML.GenAI.Core.Tests/QuantizedLinearTests.cs new file mode 100644 index 0000000000..e9687454f4 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Core.Tests/QuantizedLinearTests.cs @@ -0,0 +1,134 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using FluentAssertions; +using Microsoft.ML.GenAI.Core.Extension; +using Microsoft.ML.TestFramework; +using TorchSharp; +using Xunit; +using Xunit.Abstractions; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Core.Tests; + +public class QuantizedLinearTests : BaseTestClass +{ + public QuantizedLinearTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + public void Int4QuantizeSizeTests() + { + // meta is critical for the test + // as the size of the model to test is 372 GB + // and can't be loaded in real device like cpu or cuda + var device = "meta"; + var model = new QuantizedLinear(100000, 100, device: device); + + var sizeInBytes = model.GetSizeInBytes(); + + var sizeInGigaBytes = sizeInBytes / 1024 / 1024; + sizeInGigaBytes.Should().Be(38); + + // to int4 + model.Int4(); + var sizeInBytesAfterInt8 = model.GetSizeInBytes(); + var sizeInGigaBytesAfterInt8 = sizeInBytesAfterInt8 / 1024 / 1024; + sizeInGigaBytesAfterInt8.Should().Be(4); // 38 // 8 = 4 + } + + [Fact] + public void Int8QuantizeSizeTests() + { + // meta is critical for the test + // as the size of the model to test is 372 GB + // and can't be loaded in real device like cpu or cuda + var device = "meta"; + var model = new QuantizedLinear(100000, 100, device: device); + + var sizeInBytes = model.GetSizeInBytes(); + + var sizeInGigaBytes = sizeInBytes / 1024 / 1024; + sizeInGigaBytes.Should().Be(38); + + // to int8 + model.Int8(); + var sizeInBytesAfterInt8 = model.GetSizeInBytes(); + var sizeInGigaBytesAfterInt8 = sizeInBytesAfterInt8 / 1024 / 1024; + sizeInGigaBytesAfterInt8.Should().Be(9); // 38 // 4 = 9 + } + + [Fact] + public void Int4QuantizeForwardTest() + { + var device = "cpu"; + var model = new QuantizedLinear(123, 10, device: device); + + // set both weight and bias to rand int8 values + // and compare the result before and after ToInt8 + var input = torch.ones([10, 2200, 123], device: device); + var weight = torch.ones([10, 123], device: device, dtype: ScalarType.Int64) * -1; + var bias = torch.ones([10], device: device) * 2; + + var weightStr = weight.Peek("weight").ToString(); + + weight = (weight + 8).view(-1).to(torch.uint8); + var weightPlaceHolderDim = (int)weight.size(0); + weightPlaceHolderDim = weightPlaceHolderDim / 2 + weightPlaceHolderDim % 2; + var weightPlaceHolder = weight[..weightPlaceHolderDim]; + weightPlaceHolder = weightPlaceHolder * 16 + weight[weightPlaceHolderDim..]; + + var high4Bit = weightPlaceHolder / 16; + var low4Bit = weightPlaceHolder % 16; + weight = torch.cat(new Tensor[] { high4Bit, low4Bit }).view(10, 123); + weight = weight.to(torch.int64); + weight -= 8; + weight.Peek("weight").Should().Be(weightStr); + + model.load_state_dict(new Dictionary + { + ["weight"] = weight, + ["bias"] = bias + }); + + var resultBeforeInt4 = model.forward(input); + + model.ToInt4QuantizeModule(); + + var resultAfterInt4 = model.forward(input); + + // compare the result + resultBeforeInt4.Peek("result").Should().Be(resultAfterInt4.Peek("result")); + } + + [Fact] + public void Int8QuantizeForwardTest() + { + var device = "cpu"; + var model = new QuantizedLinear(123, 10, device: device); + + // set both weight and bias to rand int8 values + // and compare the result before and after ToInt8 + var input = torch.ones([10, 2200, 123], device: device); + var weight = torch.ones([10, 123], device: device) * -1; + var bias = torch.ones([10], device: device) * 2; + + model.load_state_dict(new Dictionary + { + ["weight"] = weight, + ["bias"] = bias + }); + + var resultBeforeInt8 = model.forward(input); + + model.ToInt8QuantizeModule(); + + var resultAfterInt8 = model.forward(input); + + resultBeforeInt8.Peek("result").Should().Be("result: sum: 312.6933 dType: Float32 shape: [10,2200,10]"); + resultAfterInt8.Peek("result").Should().Be("result: sum: 312.6933 dType: Float32 shape: [10,2200,10]"); + } +} diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KInt4QuantizeShapeTest.approved.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KInt4QuantizeShapeTest.approved.txt new file mode 100644 index 0000000000..1855e6396e --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KInt4QuantizeShapeTest.approved.txt @@ -0,0 +1,451 @@ +0: lm_head.weight shape: [32064, 3072] +1: model.embed_tokens.weight shape: [32064, 3072] +2: model.layers.0.input_layernorm.weight shape: [3072] +3: model.layers.0.mlp.down_proj.4bit_weight shape: [12582912] +4: model.layers.0.mlp.down_proj.scale shape: [3072] +5: model.layers.0.mlp.down_proj.zeroPoint shape: [1536] +6: model.layers.0.mlp.gate_up_proj.4bit_weight shape: [25165824] +7: model.layers.0.mlp.gate_up_proj.scale shape: [16384] +8: model.layers.0.mlp.gate_up_proj.zeroPoint shape: [8192] +9: model.layers.0.post_attention_layernorm.weight shape: [3072] +10: model.layers.0.self_attn.o_proj.4bit_weight shape: [4718592] +11: model.layers.0.self_attn.o_proj.scale shape: [3072] +12: model.layers.0.self_attn.o_proj.zeroPoint shape: [1536] +13: model.layers.0.self_attn.qkv_proj.4bit_weight shape: [14155776] +14: model.layers.0.self_attn.qkv_proj.scale shape: [9216] +15: model.layers.0.self_attn.qkv_proj.zeroPoint shape: [4608] +16: model.layers.1.input_layernorm.weight shape: [3072] +17: model.layers.1.mlp.down_proj.4bit_weight shape: [12582912] +18: model.layers.1.mlp.down_proj.scale shape: [3072] +19: model.layers.1.mlp.down_proj.zeroPoint shape: [1536] +20: model.layers.1.mlp.gate_up_proj.4bit_weight shape: [25165824] +21: model.layers.1.mlp.gate_up_proj.scale shape: [16384] +22: model.layers.1.mlp.gate_up_proj.zeroPoint shape: [8192] +23: model.layers.1.post_attention_layernorm.weight shape: [3072] +24: model.layers.1.self_attn.o_proj.4bit_weight shape: [4718592] +25: model.layers.1.self_attn.o_proj.scale shape: [3072] +26: model.layers.1.self_attn.o_proj.zeroPoint shape: [1536] +27: model.layers.1.self_attn.qkv_proj.4bit_weight shape: [14155776] +28: model.layers.1.self_attn.qkv_proj.scale shape: [9216] +29: model.layers.1.self_attn.qkv_proj.zeroPoint shape: [4608] +30: model.layers.10.input_layernorm.weight shape: [3072] +31: model.layers.10.mlp.down_proj.4bit_weight shape: [12582912] +32: model.layers.10.mlp.down_proj.scale shape: [3072] +33: model.layers.10.mlp.down_proj.zeroPoint shape: [1536] +34: model.layers.10.mlp.gate_up_proj.4bit_weight shape: [25165824] +35: model.layers.10.mlp.gate_up_proj.scale shape: [16384] +36: model.layers.10.mlp.gate_up_proj.zeroPoint shape: [8192] +37: model.layers.10.post_attention_layernorm.weight shape: [3072] +38: model.layers.10.self_attn.o_proj.4bit_weight shape: [4718592] +39: model.layers.10.self_attn.o_proj.scale shape: [3072] +40: model.layers.10.self_attn.o_proj.zeroPoint shape: [1536] +41: model.layers.10.self_attn.qkv_proj.4bit_weight shape: [14155776] +42: model.layers.10.self_attn.qkv_proj.scale shape: [9216] +43: model.layers.10.self_attn.qkv_proj.zeroPoint shape: [4608] +44: model.layers.11.input_layernorm.weight shape: [3072] +45: model.layers.11.mlp.down_proj.4bit_weight shape: [12582912] +46: model.layers.11.mlp.down_proj.scale shape: [3072] +47: model.layers.11.mlp.down_proj.zeroPoint shape: [1536] +48: model.layers.11.mlp.gate_up_proj.4bit_weight shape: [25165824] +49: model.layers.11.mlp.gate_up_proj.scale shape: [16384] +50: model.layers.11.mlp.gate_up_proj.zeroPoint shape: [8192] +51: model.layers.11.post_attention_layernorm.weight shape: [3072] +52: model.layers.11.self_attn.o_proj.4bit_weight shape: [4718592] +53: model.layers.11.self_attn.o_proj.scale shape: [3072] +54: model.layers.11.self_attn.o_proj.zeroPoint shape: [1536] +55: model.layers.11.self_attn.qkv_proj.4bit_weight shape: [14155776] +56: model.layers.11.self_attn.qkv_proj.scale shape: [9216] +57: model.layers.11.self_attn.qkv_proj.zeroPoint shape: [4608] +58: model.layers.12.input_layernorm.weight shape: [3072] +59: model.layers.12.mlp.down_proj.4bit_weight shape: [12582912] +60: model.layers.12.mlp.down_proj.scale shape: [3072] +61: model.layers.12.mlp.down_proj.zeroPoint shape: [1536] +62: model.layers.12.mlp.gate_up_proj.4bit_weight shape: [25165824] +63: model.layers.12.mlp.gate_up_proj.scale shape: [16384] +64: model.layers.12.mlp.gate_up_proj.zeroPoint shape: [8192] +65: model.layers.12.post_attention_layernorm.weight shape: [3072] +66: model.layers.12.self_attn.o_proj.4bit_weight shape: [4718592] +67: model.layers.12.self_attn.o_proj.scale shape: [3072] +68: model.layers.12.self_attn.o_proj.zeroPoint shape: [1536] +69: model.layers.12.self_attn.qkv_proj.4bit_weight shape: [14155776] +70: model.layers.12.self_attn.qkv_proj.scale shape: [9216] +71: model.layers.12.self_attn.qkv_proj.zeroPoint shape: [4608] +72: model.layers.13.input_layernorm.weight shape: [3072] +73: model.layers.13.mlp.down_proj.4bit_weight shape: [12582912] +74: model.layers.13.mlp.down_proj.scale shape: [3072] +75: model.layers.13.mlp.down_proj.zeroPoint shape: [1536] +76: model.layers.13.mlp.gate_up_proj.4bit_weight shape: [25165824] +77: model.layers.13.mlp.gate_up_proj.scale shape: [16384] +78: model.layers.13.mlp.gate_up_proj.zeroPoint shape: [8192] +79: model.layers.13.post_attention_layernorm.weight shape: [3072] +80: model.layers.13.self_attn.o_proj.4bit_weight shape: [4718592] +81: model.layers.13.self_attn.o_proj.scale shape: [3072] +82: model.layers.13.self_attn.o_proj.zeroPoint shape: [1536] +83: model.layers.13.self_attn.qkv_proj.4bit_weight shape: [14155776] +84: model.layers.13.self_attn.qkv_proj.scale shape: [9216] +85: model.layers.13.self_attn.qkv_proj.zeroPoint shape: [4608] +86: model.layers.14.input_layernorm.weight shape: [3072] +87: model.layers.14.mlp.down_proj.4bit_weight shape: [12582912] +88: model.layers.14.mlp.down_proj.scale shape: [3072] +89: model.layers.14.mlp.down_proj.zeroPoint shape: [1536] +90: model.layers.14.mlp.gate_up_proj.4bit_weight shape: [25165824] +91: model.layers.14.mlp.gate_up_proj.scale shape: [16384] +92: model.layers.14.mlp.gate_up_proj.zeroPoint shape: [8192] +93: model.layers.14.post_attention_layernorm.weight shape: [3072] +94: model.layers.14.self_attn.o_proj.4bit_weight shape: [4718592] +95: model.layers.14.self_attn.o_proj.scale shape: [3072] +96: model.layers.14.self_attn.o_proj.zeroPoint shape: [1536] +97: model.layers.14.self_attn.qkv_proj.4bit_weight shape: [14155776] +98: model.layers.14.self_attn.qkv_proj.scale shape: [9216] +99: model.layers.14.self_attn.qkv_proj.zeroPoint shape: [4608] +100: model.layers.15.input_layernorm.weight shape: [3072] +101: model.layers.15.mlp.down_proj.4bit_weight shape: [12582912] +102: model.layers.15.mlp.down_proj.scale shape: [3072] +103: model.layers.15.mlp.down_proj.zeroPoint shape: [1536] +104: model.layers.15.mlp.gate_up_proj.4bit_weight shape: [25165824] +105: model.layers.15.mlp.gate_up_proj.scale shape: [16384] +106: model.layers.15.mlp.gate_up_proj.zeroPoint shape: [8192] +107: model.layers.15.post_attention_layernorm.weight shape: [3072] +108: model.layers.15.self_attn.o_proj.4bit_weight shape: [4718592] +109: model.layers.15.self_attn.o_proj.scale shape: [3072] +110: model.layers.15.self_attn.o_proj.zeroPoint shape: [1536] +111: model.layers.15.self_attn.qkv_proj.4bit_weight shape: [14155776] +112: model.layers.15.self_attn.qkv_proj.scale shape: [9216] +113: model.layers.15.self_attn.qkv_proj.zeroPoint shape: [4608] +114: model.layers.16.input_layernorm.weight shape: [3072] +115: model.layers.16.mlp.down_proj.4bit_weight shape: [12582912] +116: model.layers.16.mlp.down_proj.scale shape: [3072] +117: model.layers.16.mlp.down_proj.zeroPoint shape: [1536] +118: model.layers.16.mlp.gate_up_proj.4bit_weight shape: [25165824] +119: model.layers.16.mlp.gate_up_proj.scale shape: [16384] +120: model.layers.16.mlp.gate_up_proj.zeroPoint shape: [8192] +121: model.layers.16.post_attention_layernorm.weight shape: [3072] +122: model.layers.16.self_attn.o_proj.4bit_weight shape: [4718592] +123: model.layers.16.self_attn.o_proj.scale shape: [3072] +124: model.layers.16.self_attn.o_proj.zeroPoint shape: [1536] +125: model.layers.16.self_attn.qkv_proj.4bit_weight shape: [14155776] +126: model.layers.16.self_attn.qkv_proj.scale shape: [9216] +127: model.layers.16.self_attn.qkv_proj.zeroPoint shape: [4608] +128: model.layers.17.input_layernorm.weight shape: [3072] +129: model.layers.17.mlp.down_proj.4bit_weight shape: [12582912] +130: model.layers.17.mlp.down_proj.scale shape: [3072] +131: model.layers.17.mlp.down_proj.zeroPoint shape: [1536] +132: model.layers.17.mlp.gate_up_proj.4bit_weight shape: [25165824] +133: model.layers.17.mlp.gate_up_proj.scale shape: [16384] +134: model.layers.17.mlp.gate_up_proj.zeroPoint shape: [8192] +135: model.layers.17.post_attention_layernorm.weight shape: [3072] +136: model.layers.17.self_attn.o_proj.4bit_weight shape: [4718592] +137: model.layers.17.self_attn.o_proj.scale shape: [3072] +138: model.layers.17.self_attn.o_proj.zeroPoint shape: [1536] +139: model.layers.17.self_attn.qkv_proj.4bit_weight shape: [14155776] +140: model.layers.17.self_attn.qkv_proj.scale shape: [9216] +141: model.layers.17.self_attn.qkv_proj.zeroPoint shape: [4608] +142: model.layers.18.input_layernorm.weight shape: [3072] +143: model.layers.18.mlp.down_proj.4bit_weight shape: [12582912] +144: model.layers.18.mlp.down_proj.scale shape: [3072] +145: model.layers.18.mlp.down_proj.zeroPoint shape: [1536] +146: model.layers.18.mlp.gate_up_proj.4bit_weight shape: [25165824] +147: model.layers.18.mlp.gate_up_proj.scale shape: [16384] +148: model.layers.18.mlp.gate_up_proj.zeroPoint shape: [8192] +149: model.layers.18.post_attention_layernorm.weight shape: [3072] +150: model.layers.18.self_attn.o_proj.4bit_weight shape: [4718592] +151: model.layers.18.self_attn.o_proj.scale shape: [3072] +152: model.layers.18.self_attn.o_proj.zeroPoint shape: [1536] +153: model.layers.18.self_attn.qkv_proj.4bit_weight shape: [14155776] +154: model.layers.18.self_attn.qkv_proj.scale shape: [9216] +155: model.layers.18.self_attn.qkv_proj.zeroPoint shape: [4608] +156: model.layers.19.input_layernorm.weight shape: [3072] +157: model.layers.19.mlp.down_proj.4bit_weight shape: [12582912] +158: model.layers.19.mlp.down_proj.scale shape: [3072] +159: model.layers.19.mlp.down_proj.zeroPoint shape: [1536] +160: model.layers.19.mlp.gate_up_proj.4bit_weight shape: [25165824] +161: model.layers.19.mlp.gate_up_proj.scale shape: [16384] +162: model.layers.19.mlp.gate_up_proj.zeroPoint shape: [8192] +163: model.layers.19.post_attention_layernorm.weight shape: [3072] +164: model.layers.19.self_attn.o_proj.4bit_weight shape: [4718592] +165: model.layers.19.self_attn.o_proj.scale shape: [3072] +166: model.layers.19.self_attn.o_proj.zeroPoint shape: [1536] +167: model.layers.19.self_attn.qkv_proj.4bit_weight shape: [14155776] +168: model.layers.19.self_attn.qkv_proj.scale shape: [9216] +169: model.layers.19.self_attn.qkv_proj.zeroPoint shape: [4608] +170: model.layers.2.input_layernorm.weight shape: [3072] +171: model.layers.2.mlp.down_proj.4bit_weight shape: [12582912] +172: model.layers.2.mlp.down_proj.scale shape: [3072] +173: model.layers.2.mlp.down_proj.zeroPoint shape: [1536] +174: model.layers.2.mlp.gate_up_proj.4bit_weight shape: [25165824] +175: model.layers.2.mlp.gate_up_proj.scale shape: [16384] +176: model.layers.2.mlp.gate_up_proj.zeroPoint shape: [8192] +177: model.layers.2.post_attention_layernorm.weight shape: [3072] +178: model.layers.2.self_attn.o_proj.4bit_weight shape: [4718592] +179: model.layers.2.self_attn.o_proj.scale shape: [3072] +180: model.layers.2.self_attn.o_proj.zeroPoint shape: [1536] +181: model.layers.2.self_attn.qkv_proj.4bit_weight shape: [14155776] +182: model.layers.2.self_attn.qkv_proj.scale shape: [9216] +183: model.layers.2.self_attn.qkv_proj.zeroPoint shape: [4608] +184: model.layers.20.input_layernorm.weight shape: [3072] +185: model.layers.20.mlp.down_proj.4bit_weight shape: [12582912] +186: model.layers.20.mlp.down_proj.scale shape: [3072] +187: model.layers.20.mlp.down_proj.zeroPoint shape: [1536] +188: model.layers.20.mlp.gate_up_proj.4bit_weight shape: [25165824] +189: model.layers.20.mlp.gate_up_proj.scale shape: [16384] +190: model.layers.20.mlp.gate_up_proj.zeroPoint shape: [8192] +191: model.layers.20.post_attention_layernorm.weight shape: [3072] +192: model.layers.20.self_attn.o_proj.4bit_weight shape: [4718592] +193: model.layers.20.self_attn.o_proj.scale shape: [3072] +194: model.layers.20.self_attn.o_proj.zeroPoint shape: [1536] +195: model.layers.20.self_attn.qkv_proj.4bit_weight shape: [14155776] +196: model.layers.20.self_attn.qkv_proj.scale shape: [9216] +197: model.layers.20.self_attn.qkv_proj.zeroPoint shape: [4608] +198: model.layers.21.input_layernorm.weight shape: [3072] +199: model.layers.21.mlp.down_proj.4bit_weight shape: [12582912] +200: model.layers.21.mlp.down_proj.scale shape: [3072] +201: model.layers.21.mlp.down_proj.zeroPoint shape: [1536] +202: model.layers.21.mlp.gate_up_proj.4bit_weight shape: [25165824] +203: model.layers.21.mlp.gate_up_proj.scale shape: [16384] +204: model.layers.21.mlp.gate_up_proj.zeroPoint shape: [8192] +205: model.layers.21.post_attention_layernorm.weight shape: [3072] +206: model.layers.21.self_attn.o_proj.4bit_weight shape: [4718592] +207: model.layers.21.self_attn.o_proj.scale shape: [3072] +208: model.layers.21.self_attn.o_proj.zeroPoint shape: [1536] +209: model.layers.21.self_attn.qkv_proj.4bit_weight shape: [14155776] +210: model.layers.21.self_attn.qkv_proj.scale shape: [9216] +211: model.layers.21.self_attn.qkv_proj.zeroPoint shape: [4608] +212: model.layers.22.input_layernorm.weight shape: [3072] +213: model.layers.22.mlp.down_proj.4bit_weight shape: [12582912] +214: model.layers.22.mlp.down_proj.scale shape: [3072] +215: model.layers.22.mlp.down_proj.zeroPoint shape: [1536] +216: model.layers.22.mlp.gate_up_proj.4bit_weight shape: [25165824] +217: model.layers.22.mlp.gate_up_proj.scale shape: [16384] +218: model.layers.22.mlp.gate_up_proj.zeroPoint shape: [8192] +219: model.layers.22.post_attention_layernorm.weight shape: [3072] +220: model.layers.22.self_attn.o_proj.4bit_weight shape: [4718592] +221: model.layers.22.self_attn.o_proj.scale shape: [3072] +222: model.layers.22.self_attn.o_proj.zeroPoint shape: [1536] +223: model.layers.22.self_attn.qkv_proj.4bit_weight shape: [14155776] +224: model.layers.22.self_attn.qkv_proj.scale shape: [9216] +225: model.layers.22.self_attn.qkv_proj.zeroPoint shape: [4608] +226: model.layers.23.input_layernorm.weight shape: [3072] +227: model.layers.23.mlp.down_proj.4bit_weight shape: [12582912] +228: model.layers.23.mlp.down_proj.scale shape: [3072] +229: model.layers.23.mlp.down_proj.zeroPoint shape: [1536] +230: model.layers.23.mlp.gate_up_proj.4bit_weight shape: [25165824] +231: model.layers.23.mlp.gate_up_proj.scale shape: [16384] +232: model.layers.23.mlp.gate_up_proj.zeroPoint shape: [8192] +233: model.layers.23.post_attention_layernorm.weight shape: [3072] +234: model.layers.23.self_attn.o_proj.4bit_weight shape: [4718592] +235: model.layers.23.self_attn.o_proj.scale shape: [3072] +236: model.layers.23.self_attn.o_proj.zeroPoint shape: [1536] +237: model.layers.23.self_attn.qkv_proj.4bit_weight shape: [14155776] +238: model.layers.23.self_attn.qkv_proj.scale shape: [9216] +239: model.layers.23.self_attn.qkv_proj.zeroPoint shape: [4608] +240: model.layers.24.input_layernorm.weight shape: [3072] +241: model.layers.24.mlp.down_proj.4bit_weight shape: [12582912] +242: model.layers.24.mlp.down_proj.scale shape: [3072] +243: model.layers.24.mlp.down_proj.zeroPoint shape: [1536] +244: model.layers.24.mlp.gate_up_proj.4bit_weight shape: [25165824] +245: model.layers.24.mlp.gate_up_proj.scale shape: [16384] +246: model.layers.24.mlp.gate_up_proj.zeroPoint shape: [8192] +247: model.layers.24.post_attention_layernorm.weight shape: [3072] +248: model.layers.24.self_attn.o_proj.4bit_weight shape: [4718592] +249: model.layers.24.self_attn.o_proj.scale shape: [3072] +250: model.layers.24.self_attn.o_proj.zeroPoint shape: [1536] +251: model.layers.24.self_attn.qkv_proj.4bit_weight shape: [14155776] +252: model.layers.24.self_attn.qkv_proj.scale shape: [9216] +253: model.layers.24.self_attn.qkv_proj.zeroPoint shape: [4608] +254: model.layers.25.input_layernorm.weight shape: [3072] +255: model.layers.25.mlp.down_proj.4bit_weight shape: [12582912] +256: model.layers.25.mlp.down_proj.scale shape: [3072] +257: model.layers.25.mlp.down_proj.zeroPoint shape: [1536] +258: model.layers.25.mlp.gate_up_proj.4bit_weight shape: [25165824] +259: model.layers.25.mlp.gate_up_proj.scale shape: [16384] +260: model.layers.25.mlp.gate_up_proj.zeroPoint shape: [8192] +261: model.layers.25.post_attention_layernorm.weight shape: [3072] +262: model.layers.25.self_attn.o_proj.4bit_weight shape: [4718592] +263: model.layers.25.self_attn.o_proj.scale shape: [3072] +264: model.layers.25.self_attn.o_proj.zeroPoint shape: [1536] +265: model.layers.25.self_attn.qkv_proj.4bit_weight shape: [14155776] +266: model.layers.25.self_attn.qkv_proj.scale shape: [9216] +267: model.layers.25.self_attn.qkv_proj.zeroPoint shape: [4608] +268: model.layers.26.input_layernorm.weight shape: [3072] +269: model.layers.26.mlp.down_proj.4bit_weight shape: [12582912] +270: model.layers.26.mlp.down_proj.scale shape: [3072] +271: model.layers.26.mlp.down_proj.zeroPoint shape: [1536] +272: model.layers.26.mlp.gate_up_proj.4bit_weight shape: [25165824] +273: model.layers.26.mlp.gate_up_proj.scale shape: [16384] +274: model.layers.26.mlp.gate_up_proj.zeroPoint shape: [8192] +275: model.layers.26.post_attention_layernorm.weight shape: [3072] +276: model.layers.26.self_attn.o_proj.4bit_weight shape: [4718592] +277: model.layers.26.self_attn.o_proj.scale shape: [3072] +278: model.layers.26.self_attn.o_proj.zeroPoint shape: [1536] +279: model.layers.26.self_attn.qkv_proj.4bit_weight shape: [14155776] +280: model.layers.26.self_attn.qkv_proj.scale shape: [9216] +281: model.layers.26.self_attn.qkv_proj.zeroPoint shape: [4608] +282: model.layers.27.input_layernorm.weight shape: [3072] +283: model.layers.27.mlp.down_proj.4bit_weight shape: [12582912] +284: model.layers.27.mlp.down_proj.scale shape: [3072] +285: model.layers.27.mlp.down_proj.zeroPoint shape: [1536] +286: model.layers.27.mlp.gate_up_proj.4bit_weight shape: [25165824] +287: model.layers.27.mlp.gate_up_proj.scale shape: [16384] +288: model.layers.27.mlp.gate_up_proj.zeroPoint shape: [8192] +289: model.layers.27.post_attention_layernorm.weight shape: [3072] +290: model.layers.27.self_attn.o_proj.4bit_weight shape: [4718592] +291: model.layers.27.self_attn.o_proj.scale shape: [3072] +292: model.layers.27.self_attn.o_proj.zeroPoint shape: [1536] +293: model.layers.27.self_attn.qkv_proj.4bit_weight shape: [14155776] +294: model.layers.27.self_attn.qkv_proj.scale shape: [9216] +295: model.layers.27.self_attn.qkv_proj.zeroPoint shape: [4608] +296: model.layers.28.input_layernorm.weight shape: [3072] +297: model.layers.28.mlp.down_proj.4bit_weight shape: [12582912] +298: model.layers.28.mlp.down_proj.scale shape: [3072] +299: model.layers.28.mlp.down_proj.zeroPoint shape: [1536] +300: model.layers.28.mlp.gate_up_proj.4bit_weight shape: [25165824] +301: model.layers.28.mlp.gate_up_proj.scale shape: [16384] +302: model.layers.28.mlp.gate_up_proj.zeroPoint shape: [8192] +303: model.layers.28.post_attention_layernorm.weight shape: [3072] +304: model.layers.28.self_attn.o_proj.4bit_weight shape: [4718592] +305: model.layers.28.self_attn.o_proj.scale shape: [3072] +306: model.layers.28.self_attn.o_proj.zeroPoint shape: [1536] +307: model.layers.28.self_attn.qkv_proj.4bit_weight shape: [14155776] +308: model.layers.28.self_attn.qkv_proj.scale shape: [9216] +309: model.layers.28.self_attn.qkv_proj.zeroPoint shape: [4608] +310: model.layers.29.input_layernorm.weight shape: [3072] +311: model.layers.29.mlp.down_proj.4bit_weight shape: [12582912] +312: model.layers.29.mlp.down_proj.scale shape: [3072] +313: model.layers.29.mlp.down_proj.zeroPoint shape: [1536] +314: model.layers.29.mlp.gate_up_proj.4bit_weight shape: [25165824] +315: model.layers.29.mlp.gate_up_proj.scale shape: [16384] +316: model.layers.29.mlp.gate_up_proj.zeroPoint shape: [8192] +317: model.layers.29.post_attention_layernorm.weight shape: [3072] +318: model.layers.29.self_attn.o_proj.4bit_weight shape: [4718592] +319: model.layers.29.self_attn.o_proj.scale shape: [3072] +320: model.layers.29.self_attn.o_proj.zeroPoint shape: [1536] +321: model.layers.29.self_attn.qkv_proj.4bit_weight shape: [14155776] +322: model.layers.29.self_attn.qkv_proj.scale shape: [9216] +323: model.layers.29.self_attn.qkv_proj.zeroPoint shape: [4608] +324: model.layers.3.input_layernorm.weight shape: [3072] +325: model.layers.3.mlp.down_proj.4bit_weight shape: [12582912] +326: model.layers.3.mlp.down_proj.scale shape: [3072] +327: model.layers.3.mlp.down_proj.zeroPoint shape: [1536] +328: model.layers.3.mlp.gate_up_proj.4bit_weight shape: [25165824] +329: model.layers.3.mlp.gate_up_proj.scale shape: [16384] +330: model.layers.3.mlp.gate_up_proj.zeroPoint shape: [8192] +331: model.layers.3.post_attention_layernorm.weight shape: [3072] +332: model.layers.3.self_attn.o_proj.4bit_weight shape: [4718592] +333: model.layers.3.self_attn.o_proj.scale shape: [3072] +334: model.layers.3.self_attn.o_proj.zeroPoint shape: [1536] +335: model.layers.3.self_attn.qkv_proj.4bit_weight shape: [14155776] +336: model.layers.3.self_attn.qkv_proj.scale shape: [9216] +337: model.layers.3.self_attn.qkv_proj.zeroPoint shape: [4608] +338: model.layers.30.input_layernorm.weight shape: [3072] +339: model.layers.30.mlp.down_proj.4bit_weight shape: [12582912] +340: model.layers.30.mlp.down_proj.scale shape: [3072] +341: model.layers.30.mlp.down_proj.zeroPoint shape: [1536] +342: model.layers.30.mlp.gate_up_proj.4bit_weight shape: [25165824] +343: model.layers.30.mlp.gate_up_proj.scale shape: [16384] +344: model.layers.30.mlp.gate_up_proj.zeroPoint shape: [8192] +345: model.layers.30.post_attention_layernorm.weight shape: [3072] +346: model.layers.30.self_attn.o_proj.4bit_weight shape: [4718592] +347: model.layers.30.self_attn.o_proj.scale shape: [3072] +348: model.layers.30.self_attn.o_proj.zeroPoint shape: [1536] +349: model.layers.30.self_attn.qkv_proj.4bit_weight shape: [14155776] +350: model.layers.30.self_attn.qkv_proj.scale shape: [9216] +351: model.layers.30.self_attn.qkv_proj.zeroPoint shape: [4608] +352: model.layers.31.input_layernorm.weight shape: [3072] +353: model.layers.31.mlp.down_proj.4bit_weight shape: [12582912] +354: model.layers.31.mlp.down_proj.scale shape: [3072] +355: model.layers.31.mlp.down_proj.zeroPoint shape: [1536] +356: model.layers.31.mlp.gate_up_proj.4bit_weight shape: [25165824] +357: model.layers.31.mlp.gate_up_proj.scale shape: [16384] +358: model.layers.31.mlp.gate_up_proj.zeroPoint shape: [8192] +359: model.layers.31.post_attention_layernorm.weight shape: [3072] +360: model.layers.31.self_attn.o_proj.4bit_weight shape: [4718592] +361: model.layers.31.self_attn.o_proj.scale shape: [3072] +362: model.layers.31.self_attn.o_proj.zeroPoint shape: [1536] +363: model.layers.31.self_attn.qkv_proj.4bit_weight shape: [14155776] +364: model.layers.31.self_attn.qkv_proj.scale shape: [9216] +365: model.layers.31.self_attn.qkv_proj.zeroPoint shape: [4608] +366: model.layers.4.input_layernorm.weight shape: [3072] +367: model.layers.4.mlp.down_proj.4bit_weight shape: [12582912] +368: model.layers.4.mlp.down_proj.scale shape: [3072] +369: model.layers.4.mlp.down_proj.zeroPoint shape: [1536] +370: model.layers.4.mlp.gate_up_proj.4bit_weight shape: [25165824] +371: model.layers.4.mlp.gate_up_proj.scale shape: [16384] +372: model.layers.4.mlp.gate_up_proj.zeroPoint shape: [8192] +373: model.layers.4.post_attention_layernorm.weight shape: [3072] +374: model.layers.4.self_attn.o_proj.4bit_weight shape: [4718592] +375: model.layers.4.self_attn.o_proj.scale shape: [3072] +376: model.layers.4.self_attn.o_proj.zeroPoint shape: [1536] +377: model.layers.4.self_attn.qkv_proj.4bit_weight shape: [14155776] +378: model.layers.4.self_attn.qkv_proj.scale shape: [9216] +379: model.layers.4.self_attn.qkv_proj.zeroPoint shape: [4608] +380: model.layers.5.input_layernorm.weight shape: [3072] +381: model.layers.5.mlp.down_proj.4bit_weight shape: [12582912] +382: model.layers.5.mlp.down_proj.scale shape: [3072] +383: model.layers.5.mlp.down_proj.zeroPoint shape: [1536] +384: model.layers.5.mlp.gate_up_proj.4bit_weight shape: [25165824] +385: model.layers.5.mlp.gate_up_proj.scale shape: [16384] +386: model.layers.5.mlp.gate_up_proj.zeroPoint shape: [8192] +387: model.layers.5.post_attention_layernorm.weight shape: [3072] +388: model.layers.5.self_attn.o_proj.4bit_weight shape: [4718592] +389: model.layers.5.self_attn.o_proj.scale shape: [3072] +390: model.layers.5.self_attn.o_proj.zeroPoint shape: [1536] +391: model.layers.5.self_attn.qkv_proj.4bit_weight shape: [14155776] +392: model.layers.5.self_attn.qkv_proj.scale shape: [9216] +393: model.layers.5.self_attn.qkv_proj.zeroPoint shape: [4608] +394: model.layers.6.input_layernorm.weight shape: [3072] +395: model.layers.6.mlp.down_proj.4bit_weight shape: [12582912] +396: model.layers.6.mlp.down_proj.scale shape: [3072] +397: model.layers.6.mlp.down_proj.zeroPoint shape: [1536] +398: model.layers.6.mlp.gate_up_proj.4bit_weight shape: [25165824] +399: model.layers.6.mlp.gate_up_proj.scale shape: [16384] +400: model.layers.6.mlp.gate_up_proj.zeroPoint shape: [8192] +401: model.layers.6.post_attention_layernorm.weight shape: [3072] +402: model.layers.6.self_attn.o_proj.4bit_weight shape: [4718592] +403: model.layers.6.self_attn.o_proj.scale shape: [3072] +404: model.layers.6.self_attn.o_proj.zeroPoint shape: [1536] +405: model.layers.6.self_attn.qkv_proj.4bit_weight shape: [14155776] +406: model.layers.6.self_attn.qkv_proj.scale shape: [9216] +407: model.layers.6.self_attn.qkv_proj.zeroPoint shape: [4608] +408: model.layers.7.input_layernorm.weight shape: [3072] +409: model.layers.7.mlp.down_proj.4bit_weight shape: [12582912] +410: model.layers.7.mlp.down_proj.scale shape: [3072] +411: model.layers.7.mlp.down_proj.zeroPoint shape: [1536] +412: model.layers.7.mlp.gate_up_proj.4bit_weight shape: [25165824] +413: model.layers.7.mlp.gate_up_proj.scale shape: [16384] +414: model.layers.7.mlp.gate_up_proj.zeroPoint shape: [8192] +415: model.layers.7.post_attention_layernorm.weight shape: [3072] +416: model.layers.7.self_attn.o_proj.4bit_weight shape: [4718592] +417: model.layers.7.self_attn.o_proj.scale shape: [3072] +418: model.layers.7.self_attn.o_proj.zeroPoint shape: [1536] +419: model.layers.7.self_attn.qkv_proj.4bit_weight shape: [14155776] +420: model.layers.7.self_attn.qkv_proj.scale shape: [9216] +421: model.layers.7.self_attn.qkv_proj.zeroPoint shape: [4608] +422: model.layers.8.input_layernorm.weight shape: [3072] +423: model.layers.8.mlp.down_proj.4bit_weight shape: [12582912] +424: model.layers.8.mlp.down_proj.scale shape: [3072] +425: model.layers.8.mlp.down_proj.zeroPoint shape: [1536] +426: model.layers.8.mlp.gate_up_proj.4bit_weight shape: [25165824] +427: model.layers.8.mlp.gate_up_proj.scale shape: [16384] +428: model.layers.8.mlp.gate_up_proj.zeroPoint shape: [8192] +429: model.layers.8.post_attention_layernorm.weight shape: [3072] +430: model.layers.8.self_attn.o_proj.4bit_weight shape: [4718592] +431: model.layers.8.self_attn.o_proj.scale shape: [3072] +432: model.layers.8.self_attn.o_proj.zeroPoint shape: [1536] +433: model.layers.8.self_attn.qkv_proj.4bit_weight shape: [14155776] +434: model.layers.8.self_attn.qkv_proj.scale shape: [9216] +435: model.layers.8.self_attn.qkv_proj.zeroPoint shape: [4608] +436: model.layers.9.input_layernorm.weight shape: [3072] +437: model.layers.9.mlp.down_proj.4bit_weight shape: [12582912] +438: model.layers.9.mlp.down_proj.scale shape: [3072] +439: model.layers.9.mlp.down_proj.zeroPoint shape: [1536] +440: model.layers.9.mlp.gate_up_proj.4bit_weight shape: [25165824] +441: model.layers.9.mlp.gate_up_proj.scale shape: [16384] +442: model.layers.9.mlp.gate_up_proj.zeroPoint shape: [8192] +443: model.layers.9.post_attention_layernorm.weight shape: [3072] +444: model.layers.9.self_attn.o_proj.4bit_weight shape: [4718592] +445: model.layers.9.self_attn.o_proj.scale shape: [3072] +446: model.layers.9.self_attn.o_proj.zeroPoint shape: [1536] +447: model.layers.9.self_attn.qkv_proj.4bit_weight shape: [14155776] +448: model.layers.9.self_attn.qkv_proj.scale shape: [9216] +449: model.layers.9.self_attn.qkv_proj.zeroPoint shape: [4608] +450: model.norm.weight shape: [3072] diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KInt8QuantizeShapeTest.approved.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KInt8QuantizeShapeTest.approved.txt new file mode 100644 index 0000000000..d3ab1d8010 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KInt8QuantizeShapeTest.approved.txt @@ -0,0 +1,451 @@ +0: lm_head.weight shape: [32064, 3072] +1: model.embed_tokens.weight shape: [32064, 3072] +2: model.layers.0.input_layernorm.weight shape: [3072] +3: model.layers.0.mlp.down_proj.8bit_weight shape: [3072, 8192] +4: model.layers.0.mlp.down_proj.scale shape: [3072] +5: model.layers.0.mlp.down_proj.zeroPoint shape: [3072] +6: model.layers.0.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +7: model.layers.0.mlp.gate_up_proj.scale shape: [16384] +8: model.layers.0.mlp.gate_up_proj.zeroPoint shape: [16384] +9: model.layers.0.post_attention_layernorm.weight shape: [3072] +10: model.layers.0.self_attn.o_proj.8bit_weight shape: [3072, 3072] +11: model.layers.0.self_attn.o_proj.scale shape: [3072] +12: model.layers.0.self_attn.o_proj.zeroPoint shape: [3072] +13: model.layers.0.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +14: model.layers.0.self_attn.qkv_proj.scale shape: [9216] +15: model.layers.0.self_attn.qkv_proj.zeroPoint shape: [9216] +16: model.layers.1.input_layernorm.weight shape: [3072] +17: model.layers.1.mlp.down_proj.8bit_weight shape: [3072, 8192] +18: model.layers.1.mlp.down_proj.scale shape: [3072] +19: model.layers.1.mlp.down_proj.zeroPoint shape: [3072] +20: model.layers.1.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +21: model.layers.1.mlp.gate_up_proj.scale shape: [16384] +22: model.layers.1.mlp.gate_up_proj.zeroPoint shape: [16384] +23: model.layers.1.post_attention_layernorm.weight shape: [3072] +24: model.layers.1.self_attn.o_proj.8bit_weight shape: [3072, 3072] +25: model.layers.1.self_attn.o_proj.scale shape: [3072] +26: model.layers.1.self_attn.o_proj.zeroPoint shape: [3072] +27: model.layers.1.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +28: model.layers.1.self_attn.qkv_proj.scale shape: [9216] +29: model.layers.1.self_attn.qkv_proj.zeroPoint shape: [9216] +30: model.layers.10.input_layernorm.weight shape: [3072] +31: model.layers.10.mlp.down_proj.8bit_weight shape: [3072, 8192] +32: model.layers.10.mlp.down_proj.scale shape: [3072] +33: model.layers.10.mlp.down_proj.zeroPoint shape: [3072] +34: model.layers.10.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +35: model.layers.10.mlp.gate_up_proj.scale shape: [16384] +36: model.layers.10.mlp.gate_up_proj.zeroPoint shape: [16384] +37: model.layers.10.post_attention_layernorm.weight shape: [3072] +38: model.layers.10.self_attn.o_proj.8bit_weight shape: [3072, 3072] +39: model.layers.10.self_attn.o_proj.scale shape: [3072] +40: model.layers.10.self_attn.o_proj.zeroPoint shape: [3072] +41: model.layers.10.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +42: model.layers.10.self_attn.qkv_proj.scale shape: [9216] +43: model.layers.10.self_attn.qkv_proj.zeroPoint shape: [9216] +44: model.layers.11.input_layernorm.weight shape: [3072] +45: model.layers.11.mlp.down_proj.8bit_weight shape: [3072, 8192] +46: model.layers.11.mlp.down_proj.scale shape: [3072] +47: model.layers.11.mlp.down_proj.zeroPoint shape: [3072] +48: model.layers.11.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +49: model.layers.11.mlp.gate_up_proj.scale shape: [16384] +50: model.layers.11.mlp.gate_up_proj.zeroPoint shape: [16384] +51: model.layers.11.post_attention_layernorm.weight shape: [3072] +52: model.layers.11.self_attn.o_proj.8bit_weight shape: [3072, 3072] +53: model.layers.11.self_attn.o_proj.scale shape: [3072] +54: model.layers.11.self_attn.o_proj.zeroPoint shape: [3072] +55: model.layers.11.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +56: model.layers.11.self_attn.qkv_proj.scale shape: [9216] +57: model.layers.11.self_attn.qkv_proj.zeroPoint shape: [9216] +58: model.layers.12.input_layernorm.weight shape: [3072] +59: model.layers.12.mlp.down_proj.8bit_weight shape: [3072, 8192] +60: model.layers.12.mlp.down_proj.scale shape: [3072] +61: model.layers.12.mlp.down_proj.zeroPoint shape: [3072] +62: model.layers.12.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +63: model.layers.12.mlp.gate_up_proj.scale shape: [16384] +64: model.layers.12.mlp.gate_up_proj.zeroPoint shape: [16384] +65: model.layers.12.post_attention_layernorm.weight shape: [3072] +66: model.layers.12.self_attn.o_proj.8bit_weight shape: [3072, 3072] +67: model.layers.12.self_attn.o_proj.scale shape: [3072] +68: model.layers.12.self_attn.o_proj.zeroPoint shape: [3072] +69: model.layers.12.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +70: model.layers.12.self_attn.qkv_proj.scale shape: [9216] +71: model.layers.12.self_attn.qkv_proj.zeroPoint shape: [9216] +72: model.layers.13.input_layernorm.weight shape: [3072] +73: model.layers.13.mlp.down_proj.8bit_weight shape: [3072, 8192] +74: model.layers.13.mlp.down_proj.scale shape: [3072] +75: model.layers.13.mlp.down_proj.zeroPoint shape: [3072] +76: model.layers.13.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +77: model.layers.13.mlp.gate_up_proj.scale shape: [16384] +78: model.layers.13.mlp.gate_up_proj.zeroPoint shape: [16384] +79: model.layers.13.post_attention_layernorm.weight shape: [3072] +80: model.layers.13.self_attn.o_proj.8bit_weight shape: [3072, 3072] +81: model.layers.13.self_attn.o_proj.scale shape: [3072] +82: model.layers.13.self_attn.o_proj.zeroPoint shape: [3072] +83: model.layers.13.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +84: model.layers.13.self_attn.qkv_proj.scale shape: [9216] +85: model.layers.13.self_attn.qkv_proj.zeroPoint shape: [9216] +86: model.layers.14.input_layernorm.weight shape: [3072] +87: model.layers.14.mlp.down_proj.8bit_weight shape: [3072, 8192] +88: model.layers.14.mlp.down_proj.scale shape: [3072] +89: model.layers.14.mlp.down_proj.zeroPoint shape: [3072] +90: model.layers.14.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +91: model.layers.14.mlp.gate_up_proj.scale shape: [16384] +92: model.layers.14.mlp.gate_up_proj.zeroPoint shape: [16384] +93: model.layers.14.post_attention_layernorm.weight shape: [3072] +94: model.layers.14.self_attn.o_proj.8bit_weight shape: [3072, 3072] +95: model.layers.14.self_attn.o_proj.scale shape: [3072] +96: model.layers.14.self_attn.o_proj.zeroPoint shape: [3072] +97: model.layers.14.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +98: model.layers.14.self_attn.qkv_proj.scale shape: [9216] +99: model.layers.14.self_attn.qkv_proj.zeroPoint shape: [9216] +100: model.layers.15.input_layernorm.weight shape: [3072] +101: model.layers.15.mlp.down_proj.8bit_weight shape: [3072, 8192] +102: model.layers.15.mlp.down_proj.scale shape: [3072] +103: model.layers.15.mlp.down_proj.zeroPoint shape: [3072] +104: model.layers.15.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +105: model.layers.15.mlp.gate_up_proj.scale shape: [16384] +106: model.layers.15.mlp.gate_up_proj.zeroPoint shape: [16384] +107: model.layers.15.post_attention_layernorm.weight shape: [3072] +108: model.layers.15.self_attn.o_proj.8bit_weight shape: [3072, 3072] +109: model.layers.15.self_attn.o_proj.scale shape: [3072] +110: model.layers.15.self_attn.o_proj.zeroPoint shape: [3072] +111: model.layers.15.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +112: model.layers.15.self_attn.qkv_proj.scale shape: [9216] +113: model.layers.15.self_attn.qkv_proj.zeroPoint shape: [9216] +114: model.layers.16.input_layernorm.weight shape: [3072] +115: model.layers.16.mlp.down_proj.8bit_weight shape: [3072, 8192] +116: model.layers.16.mlp.down_proj.scale shape: [3072] +117: model.layers.16.mlp.down_proj.zeroPoint shape: [3072] +118: model.layers.16.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +119: model.layers.16.mlp.gate_up_proj.scale shape: [16384] +120: model.layers.16.mlp.gate_up_proj.zeroPoint shape: [16384] +121: model.layers.16.post_attention_layernorm.weight shape: [3072] +122: model.layers.16.self_attn.o_proj.8bit_weight shape: [3072, 3072] +123: model.layers.16.self_attn.o_proj.scale shape: [3072] +124: model.layers.16.self_attn.o_proj.zeroPoint shape: [3072] +125: model.layers.16.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +126: model.layers.16.self_attn.qkv_proj.scale shape: [9216] +127: model.layers.16.self_attn.qkv_proj.zeroPoint shape: [9216] +128: model.layers.17.input_layernorm.weight shape: [3072] +129: model.layers.17.mlp.down_proj.8bit_weight shape: [3072, 8192] +130: model.layers.17.mlp.down_proj.scale shape: [3072] +131: model.layers.17.mlp.down_proj.zeroPoint shape: [3072] +132: model.layers.17.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +133: model.layers.17.mlp.gate_up_proj.scale shape: [16384] +134: model.layers.17.mlp.gate_up_proj.zeroPoint shape: [16384] +135: model.layers.17.post_attention_layernorm.weight shape: [3072] +136: model.layers.17.self_attn.o_proj.8bit_weight shape: [3072, 3072] +137: model.layers.17.self_attn.o_proj.scale shape: [3072] +138: model.layers.17.self_attn.o_proj.zeroPoint shape: [3072] +139: model.layers.17.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +140: model.layers.17.self_attn.qkv_proj.scale shape: [9216] +141: model.layers.17.self_attn.qkv_proj.zeroPoint shape: [9216] +142: model.layers.18.input_layernorm.weight shape: [3072] +143: model.layers.18.mlp.down_proj.8bit_weight shape: [3072, 8192] +144: model.layers.18.mlp.down_proj.scale shape: [3072] +145: model.layers.18.mlp.down_proj.zeroPoint shape: [3072] +146: model.layers.18.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +147: model.layers.18.mlp.gate_up_proj.scale shape: [16384] +148: model.layers.18.mlp.gate_up_proj.zeroPoint shape: [16384] +149: model.layers.18.post_attention_layernorm.weight shape: [3072] +150: model.layers.18.self_attn.o_proj.8bit_weight shape: [3072, 3072] +151: model.layers.18.self_attn.o_proj.scale shape: [3072] +152: model.layers.18.self_attn.o_proj.zeroPoint shape: [3072] +153: model.layers.18.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +154: model.layers.18.self_attn.qkv_proj.scale shape: [9216] +155: model.layers.18.self_attn.qkv_proj.zeroPoint shape: [9216] +156: model.layers.19.input_layernorm.weight shape: [3072] +157: model.layers.19.mlp.down_proj.8bit_weight shape: [3072, 8192] +158: model.layers.19.mlp.down_proj.scale shape: [3072] +159: model.layers.19.mlp.down_proj.zeroPoint shape: [3072] +160: model.layers.19.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +161: model.layers.19.mlp.gate_up_proj.scale shape: [16384] +162: model.layers.19.mlp.gate_up_proj.zeroPoint shape: [16384] +163: model.layers.19.post_attention_layernorm.weight shape: [3072] +164: model.layers.19.self_attn.o_proj.8bit_weight shape: [3072, 3072] +165: model.layers.19.self_attn.o_proj.scale shape: [3072] +166: model.layers.19.self_attn.o_proj.zeroPoint shape: [3072] +167: model.layers.19.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +168: model.layers.19.self_attn.qkv_proj.scale shape: [9216] +169: model.layers.19.self_attn.qkv_proj.zeroPoint shape: [9216] +170: model.layers.2.input_layernorm.weight shape: [3072] +171: model.layers.2.mlp.down_proj.8bit_weight shape: [3072, 8192] +172: model.layers.2.mlp.down_proj.scale shape: [3072] +173: model.layers.2.mlp.down_proj.zeroPoint shape: [3072] +174: model.layers.2.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +175: model.layers.2.mlp.gate_up_proj.scale shape: [16384] +176: model.layers.2.mlp.gate_up_proj.zeroPoint shape: [16384] +177: model.layers.2.post_attention_layernorm.weight shape: [3072] +178: model.layers.2.self_attn.o_proj.8bit_weight shape: [3072, 3072] +179: model.layers.2.self_attn.o_proj.scale shape: [3072] +180: model.layers.2.self_attn.o_proj.zeroPoint shape: [3072] +181: model.layers.2.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +182: model.layers.2.self_attn.qkv_proj.scale shape: [9216] +183: model.layers.2.self_attn.qkv_proj.zeroPoint shape: [9216] +184: model.layers.20.input_layernorm.weight shape: [3072] +185: model.layers.20.mlp.down_proj.8bit_weight shape: [3072, 8192] +186: model.layers.20.mlp.down_proj.scale shape: [3072] +187: model.layers.20.mlp.down_proj.zeroPoint shape: [3072] +188: model.layers.20.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +189: model.layers.20.mlp.gate_up_proj.scale shape: [16384] +190: model.layers.20.mlp.gate_up_proj.zeroPoint shape: [16384] +191: model.layers.20.post_attention_layernorm.weight shape: [3072] +192: model.layers.20.self_attn.o_proj.8bit_weight shape: [3072, 3072] +193: model.layers.20.self_attn.o_proj.scale shape: [3072] +194: model.layers.20.self_attn.o_proj.zeroPoint shape: [3072] +195: model.layers.20.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +196: model.layers.20.self_attn.qkv_proj.scale shape: [9216] +197: model.layers.20.self_attn.qkv_proj.zeroPoint shape: [9216] +198: model.layers.21.input_layernorm.weight shape: [3072] +199: model.layers.21.mlp.down_proj.8bit_weight shape: [3072, 8192] +200: model.layers.21.mlp.down_proj.scale shape: [3072] +201: model.layers.21.mlp.down_proj.zeroPoint shape: [3072] +202: model.layers.21.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +203: model.layers.21.mlp.gate_up_proj.scale shape: [16384] +204: model.layers.21.mlp.gate_up_proj.zeroPoint shape: [16384] +205: model.layers.21.post_attention_layernorm.weight shape: [3072] +206: model.layers.21.self_attn.o_proj.8bit_weight shape: [3072, 3072] +207: model.layers.21.self_attn.o_proj.scale shape: [3072] +208: model.layers.21.self_attn.o_proj.zeroPoint shape: [3072] +209: model.layers.21.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +210: model.layers.21.self_attn.qkv_proj.scale shape: [9216] +211: model.layers.21.self_attn.qkv_proj.zeroPoint shape: [9216] +212: model.layers.22.input_layernorm.weight shape: [3072] +213: model.layers.22.mlp.down_proj.8bit_weight shape: [3072, 8192] +214: model.layers.22.mlp.down_proj.scale shape: [3072] +215: model.layers.22.mlp.down_proj.zeroPoint shape: [3072] +216: model.layers.22.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +217: model.layers.22.mlp.gate_up_proj.scale shape: [16384] +218: model.layers.22.mlp.gate_up_proj.zeroPoint shape: [16384] +219: model.layers.22.post_attention_layernorm.weight shape: [3072] +220: model.layers.22.self_attn.o_proj.8bit_weight shape: [3072, 3072] +221: model.layers.22.self_attn.o_proj.scale shape: [3072] +222: model.layers.22.self_attn.o_proj.zeroPoint shape: [3072] +223: model.layers.22.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +224: model.layers.22.self_attn.qkv_proj.scale shape: [9216] +225: model.layers.22.self_attn.qkv_proj.zeroPoint shape: [9216] +226: model.layers.23.input_layernorm.weight shape: [3072] +227: model.layers.23.mlp.down_proj.8bit_weight shape: [3072, 8192] +228: model.layers.23.mlp.down_proj.scale shape: [3072] +229: model.layers.23.mlp.down_proj.zeroPoint shape: [3072] +230: model.layers.23.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +231: model.layers.23.mlp.gate_up_proj.scale shape: [16384] +232: model.layers.23.mlp.gate_up_proj.zeroPoint shape: [16384] +233: model.layers.23.post_attention_layernorm.weight shape: [3072] +234: model.layers.23.self_attn.o_proj.8bit_weight shape: [3072, 3072] +235: model.layers.23.self_attn.o_proj.scale shape: [3072] +236: model.layers.23.self_attn.o_proj.zeroPoint shape: [3072] +237: model.layers.23.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +238: model.layers.23.self_attn.qkv_proj.scale shape: [9216] +239: model.layers.23.self_attn.qkv_proj.zeroPoint shape: [9216] +240: model.layers.24.input_layernorm.weight shape: [3072] +241: model.layers.24.mlp.down_proj.8bit_weight shape: [3072, 8192] +242: model.layers.24.mlp.down_proj.scale shape: [3072] +243: model.layers.24.mlp.down_proj.zeroPoint shape: [3072] +244: model.layers.24.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +245: model.layers.24.mlp.gate_up_proj.scale shape: [16384] +246: model.layers.24.mlp.gate_up_proj.zeroPoint shape: [16384] +247: model.layers.24.post_attention_layernorm.weight shape: [3072] +248: model.layers.24.self_attn.o_proj.8bit_weight shape: [3072, 3072] +249: model.layers.24.self_attn.o_proj.scale shape: [3072] +250: model.layers.24.self_attn.o_proj.zeroPoint shape: [3072] +251: model.layers.24.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +252: model.layers.24.self_attn.qkv_proj.scale shape: [9216] +253: model.layers.24.self_attn.qkv_proj.zeroPoint shape: [9216] +254: model.layers.25.input_layernorm.weight shape: [3072] +255: model.layers.25.mlp.down_proj.8bit_weight shape: [3072, 8192] +256: model.layers.25.mlp.down_proj.scale shape: [3072] +257: model.layers.25.mlp.down_proj.zeroPoint shape: [3072] +258: model.layers.25.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +259: model.layers.25.mlp.gate_up_proj.scale shape: [16384] +260: model.layers.25.mlp.gate_up_proj.zeroPoint shape: [16384] +261: model.layers.25.post_attention_layernorm.weight shape: [3072] +262: model.layers.25.self_attn.o_proj.8bit_weight shape: [3072, 3072] +263: model.layers.25.self_attn.o_proj.scale shape: [3072] +264: model.layers.25.self_attn.o_proj.zeroPoint shape: [3072] +265: model.layers.25.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +266: model.layers.25.self_attn.qkv_proj.scale shape: [9216] +267: model.layers.25.self_attn.qkv_proj.zeroPoint shape: [9216] +268: model.layers.26.input_layernorm.weight shape: [3072] +269: model.layers.26.mlp.down_proj.8bit_weight shape: [3072, 8192] +270: model.layers.26.mlp.down_proj.scale shape: [3072] +271: model.layers.26.mlp.down_proj.zeroPoint shape: [3072] +272: model.layers.26.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +273: model.layers.26.mlp.gate_up_proj.scale shape: [16384] +274: model.layers.26.mlp.gate_up_proj.zeroPoint shape: [16384] +275: model.layers.26.post_attention_layernorm.weight shape: [3072] +276: model.layers.26.self_attn.o_proj.8bit_weight shape: [3072, 3072] +277: model.layers.26.self_attn.o_proj.scale shape: [3072] +278: model.layers.26.self_attn.o_proj.zeroPoint shape: [3072] +279: model.layers.26.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +280: model.layers.26.self_attn.qkv_proj.scale shape: [9216] +281: model.layers.26.self_attn.qkv_proj.zeroPoint shape: [9216] +282: model.layers.27.input_layernorm.weight shape: [3072] +283: model.layers.27.mlp.down_proj.8bit_weight shape: [3072, 8192] +284: model.layers.27.mlp.down_proj.scale shape: [3072] +285: model.layers.27.mlp.down_proj.zeroPoint shape: [3072] +286: model.layers.27.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +287: model.layers.27.mlp.gate_up_proj.scale shape: [16384] +288: model.layers.27.mlp.gate_up_proj.zeroPoint shape: [16384] +289: model.layers.27.post_attention_layernorm.weight shape: [3072] +290: model.layers.27.self_attn.o_proj.8bit_weight shape: [3072, 3072] +291: model.layers.27.self_attn.o_proj.scale shape: [3072] +292: model.layers.27.self_attn.o_proj.zeroPoint shape: [3072] +293: model.layers.27.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +294: model.layers.27.self_attn.qkv_proj.scale shape: [9216] +295: model.layers.27.self_attn.qkv_proj.zeroPoint shape: [9216] +296: model.layers.28.input_layernorm.weight shape: [3072] +297: model.layers.28.mlp.down_proj.8bit_weight shape: [3072, 8192] +298: model.layers.28.mlp.down_proj.scale shape: [3072] +299: model.layers.28.mlp.down_proj.zeroPoint shape: [3072] +300: model.layers.28.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +301: model.layers.28.mlp.gate_up_proj.scale shape: [16384] +302: model.layers.28.mlp.gate_up_proj.zeroPoint shape: [16384] +303: model.layers.28.post_attention_layernorm.weight shape: [3072] +304: model.layers.28.self_attn.o_proj.8bit_weight shape: [3072, 3072] +305: model.layers.28.self_attn.o_proj.scale shape: [3072] +306: model.layers.28.self_attn.o_proj.zeroPoint shape: [3072] +307: model.layers.28.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +308: model.layers.28.self_attn.qkv_proj.scale shape: [9216] +309: model.layers.28.self_attn.qkv_proj.zeroPoint shape: [9216] +310: model.layers.29.input_layernorm.weight shape: [3072] +311: model.layers.29.mlp.down_proj.8bit_weight shape: [3072, 8192] +312: model.layers.29.mlp.down_proj.scale shape: [3072] +313: model.layers.29.mlp.down_proj.zeroPoint shape: [3072] +314: model.layers.29.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +315: model.layers.29.mlp.gate_up_proj.scale shape: [16384] +316: model.layers.29.mlp.gate_up_proj.zeroPoint shape: [16384] +317: model.layers.29.post_attention_layernorm.weight shape: [3072] +318: model.layers.29.self_attn.o_proj.8bit_weight shape: [3072, 3072] +319: model.layers.29.self_attn.o_proj.scale shape: [3072] +320: model.layers.29.self_attn.o_proj.zeroPoint shape: [3072] +321: model.layers.29.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +322: model.layers.29.self_attn.qkv_proj.scale shape: [9216] +323: model.layers.29.self_attn.qkv_proj.zeroPoint shape: [9216] +324: model.layers.3.input_layernorm.weight shape: [3072] +325: model.layers.3.mlp.down_proj.8bit_weight shape: [3072, 8192] +326: model.layers.3.mlp.down_proj.scale shape: [3072] +327: model.layers.3.mlp.down_proj.zeroPoint shape: [3072] +328: model.layers.3.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +329: model.layers.3.mlp.gate_up_proj.scale shape: [16384] +330: model.layers.3.mlp.gate_up_proj.zeroPoint shape: [16384] +331: model.layers.3.post_attention_layernorm.weight shape: [3072] +332: model.layers.3.self_attn.o_proj.8bit_weight shape: [3072, 3072] +333: model.layers.3.self_attn.o_proj.scale shape: [3072] +334: model.layers.3.self_attn.o_proj.zeroPoint shape: [3072] +335: model.layers.3.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +336: model.layers.3.self_attn.qkv_proj.scale shape: [9216] +337: model.layers.3.self_attn.qkv_proj.zeroPoint shape: [9216] +338: model.layers.30.input_layernorm.weight shape: [3072] +339: model.layers.30.mlp.down_proj.8bit_weight shape: [3072, 8192] +340: model.layers.30.mlp.down_proj.scale shape: [3072] +341: model.layers.30.mlp.down_proj.zeroPoint shape: [3072] +342: model.layers.30.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +343: model.layers.30.mlp.gate_up_proj.scale shape: [16384] +344: model.layers.30.mlp.gate_up_proj.zeroPoint shape: [16384] +345: model.layers.30.post_attention_layernorm.weight shape: [3072] +346: model.layers.30.self_attn.o_proj.8bit_weight shape: [3072, 3072] +347: model.layers.30.self_attn.o_proj.scale shape: [3072] +348: model.layers.30.self_attn.o_proj.zeroPoint shape: [3072] +349: model.layers.30.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +350: model.layers.30.self_attn.qkv_proj.scale shape: [9216] +351: model.layers.30.self_attn.qkv_proj.zeroPoint shape: [9216] +352: model.layers.31.input_layernorm.weight shape: [3072] +353: model.layers.31.mlp.down_proj.8bit_weight shape: [3072, 8192] +354: model.layers.31.mlp.down_proj.scale shape: [3072] +355: model.layers.31.mlp.down_proj.zeroPoint shape: [3072] +356: model.layers.31.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +357: model.layers.31.mlp.gate_up_proj.scale shape: [16384] +358: model.layers.31.mlp.gate_up_proj.zeroPoint shape: [16384] +359: model.layers.31.post_attention_layernorm.weight shape: [3072] +360: model.layers.31.self_attn.o_proj.8bit_weight shape: [3072, 3072] +361: model.layers.31.self_attn.o_proj.scale shape: [3072] +362: model.layers.31.self_attn.o_proj.zeroPoint shape: [3072] +363: model.layers.31.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +364: model.layers.31.self_attn.qkv_proj.scale shape: [9216] +365: model.layers.31.self_attn.qkv_proj.zeroPoint shape: [9216] +366: model.layers.4.input_layernorm.weight shape: [3072] +367: model.layers.4.mlp.down_proj.8bit_weight shape: [3072, 8192] +368: model.layers.4.mlp.down_proj.scale shape: [3072] +369: model.layers.4.mlp.down_proj.zeroPoint shape: [3072] +370: model.layers.4.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +371: model.layers.4.mlp.gate_up_proj.scale shape: [16384] +372: model.layers.4.mlp.gate_up_proj.zeroPoint shape: [16384] +373: model.layers.4.post_attention_layernorm.weight shape: [3072] +374: model.layers.4.self_attn.o_proj.8bit_weight shape: [3072, 3072] +375: model.layers.4.self_attn.o_proj.scale shape: [3072] +376: model.layers.4.self_attn.o_proj.zeroPoint shape: [3072] +377: model.layers.4.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +378: model.layers.4.self_attn.qkv_proj.scale shape: [9216] +379: model.layers.4.self_attn.qkv_proj.zeroPoint shape: [9216] +380: model.layers.5.input_layernorm.weight shape: [3072] +381: model.layers.5.mlp.down_proj.8bit_weight shape: [3072, 8192] +382: model.layers.5.mlp.down_proj.scale shape: [3072] +383: model.layers.5.mlp.down_proj.zeroPoint shape: [3072] +384: model.layers.5.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +385: model.layers.5.mlp.gate_up_proj.scale shape: [16384] +386: model.layers.5.mlp.gate_up_proj.zeroPoint shape: [16384] +387: model.layers.5.post_attention_layernorm.weight shape: [3072] +388: model.layers.5.self_attn.o_proj.8bit_weight shape: [3072, 3072] +389: model.layers.5.self_attn.o_proj.scale shape: [3072] +390: model.layers.5.self_attn.o_proj.zeroPoint shape: [3072] +391: model.layers.5.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +392: model.layers.5.self_attn.qkv_proj.scale shape: [9216] +393: model.layers.5.self_attn.qkv_proj.zeroPoint shape: [9216] +394: model.layers.6.input_layernorm.weight shape: [3072] +395: model.layers.6.mlp.down_proj.8bit_weight shape: [3072, 8192] +396: model.layers.6.mlp.down_proj.scale shape: [3072] +397: model.layers.6.mlp.down_proj.zeroPoint shape: [3072] +398: model.layers.6.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +399: model.layers.6.mlp.gate_up_proj.scale shape: [16384] +400: model.layers.6.mlp.gate_up_proj.zeroPoint shape: [16384] +401: model.layers.6.post_attention_layernorm.weight shape: [3072] +402: model.layers.6.self_attn.o_proj.8bit_weight shape: [3072, 3072] +403: model.layers.6.self_attn.o_proj.scale shape: [3072] +404: model.layers.6.self_attn.o_proj.zeroPoint shape: [3072] +405: model.layers.6.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +406: model.layers.6.self_attn.qkv_proj.scale shape: [9216] +407: model.layers.6.self_attn.qkv_proj.zeroPoint shape: [9216] +408: model.layers.7.input_layernorm.weight shape: [3072] +409: model.layers.7.mlp.down_proj.8bit_weight shape: [3072, 8192] +410: model.layers.7.mlp.down_proj.scale shape: [3072] +411: model.layers.7.mlp.down_proj.zeroPoint shape: [3072] +412: model.layers.7.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +413: model.layers.7.mlp.gate_up_proj.scale shape: [16384] +414: model.layers.7.mlp.gate_up_proj.zeroPoint shape: [16384] +415: model.layers.7.post_attention_layernorm.weight shape: [3072] +416: model.layers.7.self_attn.o_proj.8bit_weight shape: [3072, 3072] +417: model.layers.7.self_attn.o_proj.scale shape: [3072] +418: model.layers.7.self_attn.o_proj.zeroPoint shape: [3072] +419: model.layers.7.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +420: model.layers.7.self_attn.qkv_proj.scale shape: [9216] +421: model.layers.7.self_attn.qkv_proj.zeroPoint shape: [9216] +422: model.layers.8.input_layernorm.weight shape: [3072] +423: model.layers.8.mlp.down_proj.8bit_weight shape: [3072, 8192] +424: model.layers.8.mlp.down_proj.scale shape: [3072] +425: model.layers.8.mlp.down_proj.zeroPoint shape: [3072] +426: model.layers.8.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +427: model.layers.8.mlp.gate_up_proj.scale shape: [16384] +428: model.layers.8.mlp.gate_up_proj.zeroPoint shape: [16384] +429: model.layers.8.post_attention_layernorm.weight shape: [3072] +430: model.layers.8.self_attn.o_proj.8bit_weight shape: [3072, 3072] +431: model.layers.8.self_attn.o_proj.scale shape: [3072] +432: model.layers.8.self_attn.o_proj.zeroPoint shape: [3072] +433: model.layers.8.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +434: model.layers.8.self_attn.qkv_proj.scale shape: [9216] +435: model.layers.8.self_attn.qkv_proj.zeroPoint shape: [9216] +436: model.layers.9.input_layernorm.weight shape: [3072] +437: model.layers.9.mlp.down_proj.8bit_weight shape: [3072, 8192] +438: model.layers.9.mlp.down_proj.scale shape: [3072] +439: model.layers.9.mlp.down_proj.zeroPoint shape: [3072] +440: model.layers.9.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +441: model.layers.9.mlp.gate_up_proj.scale shape: [16384] +442: model.layers.9.mlp.gate_up_proj.zeroPoint shape: [16384] +443: model.layers.9.post_attention_layernorm.weight shape: [3072] +444: model.layers.9.self_attn.o_proj.8bit_weight shape: [3072, 3072] +445: model.layers.9.self_attn.o_proj.scale shape: [3072] +446: model.layers.9.self_attn.o_proj.zeroPoint shape: [3072] +447: model.layers.9.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +448: model.layers.9.self_attn.qkv_proj.scale shape: [9216] +449: model.layers.9.self_attn.qkv_proj.zeroPoint shape: [9216] +450: model.norm.weight shape: [3072] diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KInt8QuantizeShapeTest.received.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KInt8QuantizeShapeTest.received.txt new file mode 100644 index 0000000000..d9c51c94a9 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KInt8QuantizeShapeTest.received.txt @@ -0,0 +1,451 @@ +0: lm_head.weight shape: [32064, 3072] +1: model.embed_tokens.weight shape: [32064, 3072] +2: model.layers.0.input_layernorm.weight shape: [3072] +3: model.layers.0.mlp.down_proj.8bit_weight shape: [3072, 8192] +4: model.layers.0.mlp.down_proj.scale shape: [3072] +5: model.layers.0.mlp.down_proj.zeroPoint shape: [3072] +6: model.layers.0.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +7: model.layers.0.mlp.gate_up_proj.scale shape: [16384] +8: model.layers.0.mlp.gate_up_proj.zeroPoint shape: [16384] +9: model.layers.0.post_attention_layernorm.weight shape: [3072] +10: model.layers.0.self_attn.o_proj.8bit_weight shape: [3072, 3072] +11: model.layers.0.self_attn.o_proj.scale shape: [3072] +12: model.layers.0.self_attn.o_proj.zeroPoint shape: [3072] +13: model.layers.0.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +14: model.layers.0.self_attn.qkv_proj.scale shape: [9216] +15: model.layers.0.self_attn.qkv_proj.zeroPoint shape: [9216] +16: model.layers.1.input_layernorm.weight shape: [3072] +17: model.layers.1.mlp.down_proj.8bit_weight shape: [3072, 8192] +18: model.layers.1.mlp.down_proj.scale shape: [3072] +19: model.layers.1.mlp.down_proj.zeroPoint shape: [3072] +20: model.layers.1.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +21: model.layers.1.mlp.gate_up_proj.scale shape: [16384] +22: model.layers.1.mlp.gate_up_proj.zeroPoint shape: [16384] +23: model.layers.1.post_attention_layernorm.weight shape: [3072] +24: model.layers.1.self_attn.o_proj.8bit_weight shape: [3072, 3072] +25: model.layers.1.self_attn.o_proj.scale shape: [3072] +26: model.layers.1.self_attn.o_proj.zeroPoint shape: [3072] +27: model.layers.1.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +28: model.layers.1.self_attn.qkv_proj.scale shape: [9216] +29: model.layers.1.self_attn.qkv_proj.zeroPoint shape: [9216] +30: model.layers.10.input_layernorm.weight shape: [3072] +31: model.layers.10.mlp.down_proj.8bit_weight shape: [3072, 8192] +32: model.layers.10.mlp.down_proj.scale shape: [3072] +33: model.layers.10.mlp.down_proj.zeroPoint shape: [3072] +34: model.layers.10.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +35: model.layers.10.mlp.gate_up_proj.scale shape: [16384] +36: model.layers.10.mlp.gate_up_proj.zeroPoint shape: [16384] +37: model.layers.10.post_attention_layernorm.weight shape: [3072] +38: model.layers.10.self_attn.o_proj.8bit_weight shape: [3072, 3072] +39: model.layers.10.self_attn.o_proj.scale shape: [3072] +40: model.layers.10.self_attn.o_proj.zeroPoint shape: [3072] +41: model.layers.10.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +42: model.layers.10.self_attn.qkv_proj.scale shape: [9216] +43: model.layers.10.self_attn.qkv_proj.zeroPoint shape: [9216] +44: model.layers.11.input_layernorm.weight shape: [3072] +45: model.layers.11.mlp.down_proj.8bit_weight shape: [3072, 8192] +46: model.layers.11.mlp.down_proj.scale shape: [3072] +47: model.layers.11.mlp.down_proj.zeroPoint shape: [3072] +48: model.layers.11.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +49: model.layers.11.mlp.gate_up_proj.scale shape: [16384] +50: model.layers.11.mlp.gate_up_proj.zeroPoint shape: [16384] +51: model.layers.11.post_attention_layernorm.weight shape: [3072] +52: model.layers.11.self_attn.o_proj.8bit_weight shape: [3072, 3072] +53: model.layers.11.self_attn.o_proj.scale shape: [3072] +54: model.layers.11.self_attn.o_proj.zeroPoint shape: [3072] +55: model.layers.11.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +56: model.layers.11.self_attn.qkv_proj.scale shape: [9216] +57: model.layers.11.self_attn.qkv_proj.zeroPoint shape: [9216] +58: model.layers.12.input_layernorm.weight shape: [3072] +59: model.layers.12.mlp.down_proj.8bit_weight shape: [3072, 8192] +60: model.layers.12.mlp.down_proj.scale shape: [3072] +61: model.layers.12.mlp.down_proj.zeroPoint shape: [3072] +62: model.layers.12.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +63: model.layers.12.mlp.gate_up_proj.scale shape: [16384] +64: model.layers.12.mlp.gate_up_proj.zeroPoint shape: [16384] +65: model.layers.12.post_attention_layernorm.weight shape: [3072] +66: model.layers.12.self_attn.o_proj.8bit_weight shape: [3072, 3072] +67: model.layers.12.self_attn.o_proj.scale shape: [3072] +68: model.layers.12.self_attn.o_proj.zeroPoint shape: [3072] +69: model.layers.12.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +70: model.layers.12.self_attn.qkv_proj.scale shape: [9216] +71: model.layers.12.self_attn.qkv_proj.zeroPoint shape: [9216] +72: model.layers.13.input_layernorm.weight shape: [3072] +73: model.layers.13.mlp.down_proj.8bit_weight shape: [3072, 8192] +74: model.layers.13.mlp.down_proj.scale shape: [3072] +75: model.layers.13.mlp.down_proj.zeroPoint shape: [3072] +76: model.layers.13.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +77: model.layers.13.mlp.gate_up_proj.scale shape: [16384] +78: model.layers.13.mlp.gate_up_proj.zeroPoint shape: [16384] +79: model.layers.13.post_attention_layernorm.weight shape: [3072] +80: model.layers.13.self_attn.o_proj.8bit_weight shape: [3072, 3072] +81: model.layers.13.self_attn.o_proj.scale shape: [3072] +82: model.layers.13.self_attn.o_proj.zeroPoint shape: [3072] +83: model.layers.13.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +84: model.layers.13.self_attn.qkv_proj.scale shape: [9216] +85: model.layers.13.self_attn.qkv_proj.zeroPoint shape: [9216] +86: model.layers.14.input_layernorm.weight shape: [3072] +87: model.layers.14.mlp.down_proj.8bit_weight shape: [3072, 8192] +88: model.layers.14.mlp.down_proj.scale shape: [3072] +89: model.layers.14.mlp.down_proj.zeroPoint shape: [3072] +90: model.layers.14.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +91: model.layers.14.mlp.gate_up_proj.scale shape: [16384] +92: model.layers.14.mlp.gate_up_proj.zeroPoint shape: [16384] +93: model.layers.14.post_attention_layernorm.weight shape: [3072] +94: model.layers.14.self_attn.o_proj.8bit_weight shape: [3072, 3072] +95: model.layers.14.self_attn.o_proj.scale shape: [3072] +96: model.layers.14.self_attn.o_proj.zeroPoint shape: [3072] +97: model.layers.14.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +98: model.layers.14.self_attn.qkv_proj.scale shape: [9216] +99: model.layers.14.self_attn.qkv_proj.zeroPoint shape: [9216] +100: model.layers.15.input_layernorm.weight shape: [3072] +101: model.layers.15.mlp.down_proj.8bit_weight shape: [3072, 8192] +102: model.layers.15.mlp.down_proj.scale shape: [3072] +103: model.layers.15.mlp.down_proj.zeroPoint shape: [3072] +104: model.layers.15.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +105: model.layers.15.mlp.gate_up_proj.scale shape: [16384] +106: model.layers.15.mlp.gate_up_proj.zeroPoint shape: [16384] +107: model.layers.15.post_attention_layernorm.weight shape: [3072] +108: model.layers.15.self_attn.o_proj.8bit_weight shape: [3072, 3072] +109: model.layers.15.self_attn.o_proj.scale shape: [3072] +110: model.layers.15.self_attn.o_proj.zeroPoint shape: [3072] +111: model.layers.15.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +112: model.layers.15.self_attn.qkv_proj.scale shape: [9216] +113: model.layers.15.self_attn.qkv_proj.zeroPoint shape: [9216] +114: model.layers.16.input_layernorm.weight shape: [3072] +115: model.layers.16.mlp.down_proj.8bit_weight shape: [3072, 8192] +116: model.layers.16.mlp.down_proj.scale shape: [3072] +117: model.layers.16.mlp.down_proj.zeroPoint shape: [3072] +118: model.layers.16.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +119: model.layers.16.mlp.gate_up_proj.scale shape: [16384] +120: model.layers.16.mlp.gate_up_proj.zeroPoint shape: [16384] +121: model.layers.16.post_attention_layernorm.weight shape: [3072] +122: model.layers.16.self_attn.o_proj.8bit_weight shape: [3072, 3072] +123: model.layers.16.self_attn.o_proj.scale shape: [3072] +124: model.layers.16.self_attn.o_proj.zeroPoint shape: [3072] +125: model.layers.16.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +126: model.layers.16.self_attn.qkv_proj.scale shape: [9216] +127: model.layers.16.self_attn.qkv_proj.zeroPoint shape: [9216] +128: model.layers.17.input_layernorm.weight shape: [3072] +129: model.layers.17.mlp.down_proj.8bit_weight shape: [3072, 8192] +130: model.layers.17.mlp.down_proj.scale shape: [3072] +131: model.layers.17.mlp.down_proj.zeroPoint shape: [3072] +132: model.layers.17.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +133: model.layers.17.mlp.gate_up_proj.scale shape: [16384] +134: model.layers.17.mlp.gate_up_proj.zeroPoint shape: [16384] +135: model.layers.17.post_attention_layernorm.weight shape: [3072] +136: model.layers.17.self_attn.o_proj.8bit_weight shape: [3072, 3072] +137: model.layers.17.self_attn.o_proj.scale shape: [3072] +138: model.layers.17.self_attn.o_proj.zeroPoint shape: [3072] +139: model.layers.17.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +140: model.layers.17.self_attn.qkv_proj.scale shape: [9216] +141: model.layers.17.self_attn.qkv_proj.zeroPoint shape: [9216] +142: model.layers.18.input_layernorm.weight shape: [3072] +143: model.layers.18.mlp.down_proj.8bit_weight shape: [3072, 8192] +144: model.layers.18.mlp.down_proj.scale shape: [3072] +145: model.layers.18.mlp.down_proj.zeroPoint shape: [3072] +146: model.layers.18.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +147: model.layers.18.mlp.gate_up_proj.scale shape: [16384] +148: model.layers.18.mlp.gate_up_proj.zeroPoint shape: [16384] +149: model.layers.18.post_attention_layernorm.weight shape: [3072] +150: model.layers.18.self_attn.o_proj.8bit_weight shape: [3072, 3072] +151: model.layers.18.self_attn.o_proj.scale shape: [3072] +152: model.layers.18.self_attn.o_proj.zeroPoint shape: [3072] +153: model.layers.18.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +154: model.layers.18.self_attn.qkv_proj.scale shape: [9216] +155: model.layers.18.self_attn.qkv_proj.zeroPoint shape: [9216] +156: model.layers.19.input_layernorm.weight shape: [3072] +157: model.layers.19.mlp.down_proj.8bit_weight shape: [3072, 8192] +158: model.layers.19.mlp.down_proj.scale shape: [3072] +159: model.layers.19.mlp.down_proj.zeroPoint shape: [3072] +160: model.layers.19.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +161: model.layers.19.mlp.gate_up_proj.scale shape: [16384] +162: model.layers.19.mlp.gate_up_proj.zeroPoint shape: [16384] +163: model.layers.19.post_attention_layernorm.weight shape: [3072] +164: model.layers.19.self_attn.o_proj.8bit_weight shape: [3072, 3072] +165: model.layers.19.self_attn.o_proj.scale shape: [3072] +166: model.layers.19.self_attn.o_proj.zeroPoint shape: [3072] +167: model.layers.19.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +168: model.layers.19.self_attn.qkv_proj.scale shape: [9216] +169: model.layers.19.self_attn.qkv_proj.zeroPoint shape: [9216] +170: model.layers.2.input_layernorm.weight shape: [3072] +171: model.layers.2.mlp.down_proj.8bit_weight shape: [3072, 8192] +172: model.layers.2.mlp.down_proj.scale shape: [3072] +173: model.layers.2.mlp.down_proj.zeroPoint shape: [3072] +174: model.layers.2.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +175: model.layers.2.mlp.gate_up_proj.scale shape: [16384] +176: model.layers.2.mlp.gate_up_proj.zeroPoint shape: [16384] +177: model.layers.2.post_attention_layernorm.weight shape: [3072] +178: model.layers.2.self_attn.o_proj.8bit_weight shape: [3072, 3072] +179: model.layers.2.self_attn.o_proj.scale shape: [3072] +180: model.layers.2.self_attn.o_proj.zeroPoint shape: [3072] +181: model.layers.2.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +182: model.layers.2.self_attn.qkv_proj.scale shape: [9216] +183: model.layers.2.self_attn.qkv_proj.zeroPoint shape: [9216] +184: model.layers.20.input_layernorm.weight shape: [3072] +185: model.layers.20.mlp.down_proj.8bit_weight shape: [3072, 8192] +186: model.layers.20.mlp.down_proj.scale shape: [3072] +187: model.layers.20.mlp.down_proj.zeroPoint shape: [3072] +188: model.layers.20.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +189: model.layers.20.mlp.gate_up_proj.scale shape: [16384] +190: model.layers.20.mlp.gate_up_proj.zeroPoint shape: [16384] +191: model.layers.20.post_attention_layernorm.weight shape: [3072] +192: model.layers.20.self_attn.o_proj.8bit_weight shape: [3072, 3072] +193: model.layers.20.self_attn.o_proj.scale shape: [3072] +194: model.layers.20.self_attn.o_proj.zeroPoint shape: [3072] +195: model.layers.20.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +196: model.layers.20.self_attn.qkv_proj.scale shape: [9216] +197: model.layers.20.self_attn.qkv_proj.zeroPoint shape: [9216] +198: model.layers.21.input_layernorm.weight shape: [3072] +199: model.layers.21.mlp.down_proj.8bit_weight shape: [3072, 8192] +200: model.layers.21.mlp.down_proj.scale shape: [3072] +201: model.layers.21.mlp.down_proj.zeroPoint shape: [3072] +202: model.layers.21.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +203: model.layers.21.mlp.gate_up_proj.scale shape: [16384] +204: model.layers.21.mlp.gate_up_proj.zeroPoint shape: [16384] +205: model.layers.21.post_attention_layernorm.weight shape: [3072] +206: model.layers.21.self_attn.o_proj.8bit_weight shape: [3072, 3072] +207: model.layers.21.self_attn.o_proj.scale shape: [3072] +208: model.layers.21.self_attn.o_proj.zeroPoint shape: [3072] +209: model.layers.21.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +210: model.layers.21.self_attn.qkv_proj.scale shape: [9216] +211: model.layers.21.self_attn.qkv_proj.zeroPoint shape: [9216] +212: model.layers.22.input_layernorm.weight shape: [3072] +213: model.layers.22.mlp.down_proj.8bit_weight shape: [3072, 8192] +214: model.layers.22.mlp.down_proj.scale shape: [3072] +215: model.layers.22.mlp.down_proj.zeroPoint shape: [3072] +216: model.layers.22.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +217: model.layers.22.mlp.gate_up_proj.scale shape: [16384] +218: model.layers.22.mlp.gate_up_proj.zeroPoint shape: [16384] +219: model.layers.22.post_attention_layernorm.weight shape: [3072] +220: model.layers.22.self_attn.o_proj.8bit_weight shape: [3072, 3072] +221: model.layers.22.self_attn.o_proj.scale shape: [3072] +222: model.layers.22.self_attn.o_proj.zeroPoint shape: [3072] +223: model.layers.22.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +224: model.layers.22.self_attn.qkv_proj.scale shape: [9216] +225: model.layers.22.self_attn.qkv_proj.zeroPoint shape: [9216] +226: model.layers.23.input_layernorm.weight shape: [3072] +227: model.layers.23.mlp.down_proj.8bit_weight shape: [3072, 8192] +228: model.layers.23.mlp.down_proj.scale shape: [3072] +229: model.layers.23.mlp.down_proj.zeroPoint shape: [3072] +230: model.layers.23.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +231: model.layers.23.mlp.gate_up_proj.scale shape: [16384] +232: model.layers.23.mlp.gate_up_proj.zeroPoint shape: [16384] +233: model.layers.23.post_attention_layernorm.weight shape: [3072] +234: model.layers.23.self_attn.o_proj.8bit_weight shape: [3072, 3072] +235: model.layers.23.self_attn.o_proj.scale shape: [3072] +236: model.layers.23.self_attn.o_proj.zeroPoint shape: [3072] +237: model.layers.23.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +238: model.layers.23.self_attn.qkv_proj.scale shape: [9216] +239: model.layers.23.self_attn.qkv_proj.zeroPoint shape: [9216] +240: model.layers.24.input_layernorm.weight shape: [3072] +241: model.layers.24.mlp.down_proj.8bit_weight shape: [3072, 8192] +242: model.layers.24.mlp.down_proj.scale shape: [3072] +243: model.layers.24.mlp.down_proj.zeroPoint shape: [3072] +244: model.layers.24.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +245: model.layers.24.mlp.gate_up_proj.scale shape: [16384] +246: model.layers.24.mlp.gate_up_proj.zeroPoint shape: [16384] +247: model.layers.24.post_attention_layernorm.weight shape: [3072] +248: model.layers.24.self_attn.o_proj.8bit_weight shape: [3072, 3072] +249: model.layers.24.self_attn.o_proj.scale shape: [3072] +250: model.layers.24.self_attn.o_proj.zeroPoint shape: [3072] +251: model.layers.24.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +252: model.layers.24.self_attn.qkv_proj.scale shape: [9216] +253: model.layers.24.self_attn.qkv_proj.zeroPoint shape: [9216] +254: model.layers.25.input_layernorm.weight shape: [3072] +255: model.layers.25.mlp.down_proj.8bit_weight shape: [3072, 8192] +256: model.layers.25.mlp.down_proj.scale shape: [3072] +257: model.layers.25.mlp.down_proj.zeroPoint shape: [3072] +258: model.layers.25.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +259: model.layers.25.mlp.gate_up_proj.scale shape: [16384] +260: model.layers.25.mlp.gate_up_proj.zeroPoint shape: [16384] +261: model.layers.25.post_attention_layernorm.weight shape: [3072] +262: model.layers.25.self_attn.o_proj.8bit_weight shape: [3072, 3072] +263: model.layers.25.self_attn.o_proj.scale shape: [3072] +264: model.layers.25.self_attn.o_proj.zeroPoint shape: [3072] +265: model.layers.25.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +266: model.layers.25.self_attn.qkv_proj.scale shape: [9216] +267: model.layers.25.self_attn.qkv_proj.zeroPoint shape: [9216] +268: model.layers.26.input_layernorm.weight shape: [3072] +269: model.layers.26.mlp.down_proj.8bit_weight shape: [3072, 8192] +270: model.layers.26.mlp.down_proj.scale shape: [3072] +271: model.layers.26.mlp.down_proj.zeroPoint shape: [3072] +272: model.layers.26.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +273: model.layers.26.mlp.gate_up_proj.scale shape: [16384] +274: model.layers.26.mlp.gate_up_proj.zeroPoint shape: [16384] +275: model.layers.26.post_attention_layernorm.weight shape: [3072] +276: model.layers.26.self_attn.o_proj.8bit_weight shape: [3072, 3072] +277: model.layers.26.self_attn.o_proj.scale shape: [3072] +278: model.layers.26.self_attn.o_proj.zeroPoint shape: [3072] +279: model.layers.26.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +280: model.layers.26.self_attn.qkv_proj.scale shape: [9216] +281: model.layers.26.self_attn.qkv_proj.zeroPoint shape: [9216] +282: model.layers.27.input_layernorm.weight shape: [3072] +283: model.layers.27.mlp.down_proj.8bit_weight shape: [3072, 8192] +284: model.layers.27.mlp.down_proj.scale shape: [3072] +285: model.layers.27.mlp.down_proj.zeroPoint shape: [3072] +286: model.layers.27.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +287: model.layers.27.mlp.gate_up_proj.scale shape: [16384] +288: model.layers.27.mlp.gate_up_proj.zeroPoint shape: [16384] +289: model.layers.27.post_attention_layernorm.weight shape: [3072] +290: model.layers.27.self_attn.o_proj.8bit_weight shape: [3072, 3072] +291: model.layers.27.self_attn.o_proj.scale shape: [3072] +292: model.layers.27.self_attn.o_proj.zeroPoint shape: [3072] +293: model.layers.27.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +294: model.layers.27.self_attn.qkv_proj.scale shape: [9216] +295: model.layers.27.self_attn.qkv_proj.zeroPoint shape: [9216] +296: model.layers.28.input_layernorm.weight shape: [3072] +297: model.layers.28.mlp.down_proj.8bit_weight shape: [3072, 8192] +298: model.layers.28.mlp.down_proj.scale shape: [3072] +299: model.layers.28.mlp.down_proj.zeroPoint shape: [3072] +300: model.layers.28.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +301: model.layers.28.mlp.gate_up_proj.scale shape: [16384] +302: model.layers.28.mlp.gate_up_proj.zeroPoint shape: [16384] +303: model.layers.28.post_attention_layernorm.weight shape: [3072] +304: model.layers.28.self_attn.o_proj.8bit_weight shape: [3072, 3072] +305: model.layers.28.self_attn.o_proj.scale shape: [3072] +306: model.layers.28.self_attn.o_proj.zeroPoint shape: [3072] +307: model.layers.28.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +308: model.layers.28.self_attn.qkv_proj.scale shape: [9216] +309: model.layers.28.self_attn.qkv_proj.zeroPoint shape: [9216] +310: model.layers.29.input_layernorm.weight shape: [3072] +311: model.layers.29.mlp.down_proj.8bit_weight shape: [3072, 8192] +312: model.layers.29.mlp.down_proj.scale shape: [3072] +313: model.layers.29.mlp.down_proj.zeroPoint shape: [3072] +314: model.layers.29.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +315: model.layers.29.mlp.gate_up_proj.scale shape: [16384] +316: model.layers.29.mlp.gate_up_proj.zeroPoint shape: [16384] +317: model.layers.29.post_attention_layernorm.weight shape: [3072] +318: model.layers.29.self_attn.o_proj.8bit_weight shape: [3072, 3072] +319: model.layers.29.self_attn.o_proj.scale shape: [3072] +320: model.layers.29.self_attn.o_proj.zeroPoint shape: [3072] +321: model.layers.29.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +322: model.layers.29.self_attn.qkv_proj.scale shape: [9216] +323: model.layers.29.self_attn.qkv_proj.zeroPoint shape: [9216] +324: model.layers.3.input_layernorm.weight shape: [3072] +325: model.layers.3.mlp.down_proj.8bit_weight shape: [3072, 8192] +326: model.layers.3.mlp.down_proj.scale shape: [3072] +327: model.layers.3.mlp.down_proj.zeroPoint shape: [3072] +328: model.layers.3.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +329: model.layers.3.mlp.gate_up_proj.scale shape: [16384] +330: model.layers.3.mlp.gate_up_proj.zeroPoint shape: [16384] +331: model.layers.3.post_attention_layernorm.weight shape: [3072] +332: model.layers.3.self_attn.o_proj.8bit_weight shape: [3072, 3072] +333: model.layers.3.self_attn.o_proj.scale shape: [3072] +334: model.layers.3.self_attn.o_proj.zeroPoint shape: [3072] +335: model.layers.3.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +336: model.layers.3.self_attn.qkv_proj.scale shape: [9216] +337: model.layers.3.self_attn.qkv_proj.zeroPoint shape: [9216] +338: model.layers.30.input_layernorm.weight shape: [3072] +339: model.layers.30.mlp.down_proj.8bit_weight shape: [3072, 8192] +340: model.layers.30.mlp.down_proj.scale shape: [3072] +341: model.layers.30.mlp.down_proj.zeroPoint shape: [3072] +342: model.layers.30.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +343: model.layers.30.mlp.gate_up_proj.scale shape: [16384] +344: model.layers.30.mlp.gate_up_proj.zeroPoint shape: [16384] +345: model.layers.30.post_attention_layernorm.weight shape: [3072] +346: model.layers.30.self_attn.o_proj.8bit_weight shape: [3072, 3072] +347: model.layers.30.self_attn.o_proj.scale shape: [3072] +348: model.layers.30.self_attn.o_proj.zeroPoint shape: [3072] +349: model.layers.30.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +350: model.layers.30.self_attn.qkv_proj.scale shape: [9216] +351: model.layers.30.self_attn.qkv_proj.zeroPoint shape: [9216] +352: model.layers.31.input_layernorm.weight shape: [3072] +353: model.layers.31.mlp.down_proj.8bit_weight shape: [3072, 8192] +354: model.layers.31.mlp.down_proj.scale shape: [3072] +355: model.layers.31.mlp.down_proj.zeroPoint shape: [3072] +356: model.layers.31.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +357: model.layers.31.mlp.gate_up_proj.scale shape: [16384] +358: model.layers.31.mlp.gate_up_proj.zeroPoint shape: [16384] +359: model.layers.31.post_attention_layernorm.weight shape: [3072] +360: model.layers.31.self_attn.o_proj.8bit_weight shape: [3072, 3072] +361: model.layers.31.self_attn.o_proj.scale shape: [3072] +362: model.layers.31.self_attn.o_proj.zeroPoint shape: [3072] +363: model.layers.31.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +364: model.layers.31.self_attn.qkv_proj.scale shape: [9216] +365: model.layers.31.self_attn.qkv_proj.zeroPoint shape: [9216] +366: model.layers.4.input_layernorm.weight shape: [3072] +367: model.layers.4.mlp.down_proj.8bit_weight shape: [3072, 8192] +368: model.layers.4.mlp.down_proj.scale shape: [3072] +369: model.layers.4.mlp.down_proj.zeroPoint shape: [3072] +370: model.layers.4.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +371: model.layers.4.mlp.gate_up_proj.scale shape: [16384] +372: model.layers.4.mlp.gate_up_proj.zeroPoint shape: [16384] +373: model.layers.4.post_attention_layernorm.weight shape: [3072] +374: model.layers.4.self_attn.o_proj.8bit_weight shape: [3072, 3072] +375: model.layers.4.self_attn.o_proj.scale shape: [3072] +376: model.layers.4.self_attn.o_proj.zeroPoint shape: [3072] +377: model.layers.4.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +378: model.layers.4.self_attn.qkv_proj.scale shape: [9216] +379: model.layers.4.self_attn.qkv_proj.zeroPoint shape: [9216] +380: model.layers.5.input_layernorm.weight shape: [3072] +381: model.layers.5.mlp.down_proj.8bit_weight shape: [3072, 8192] +382: model.layers.5.mlp.down_proj.scale shape: [3072] +383: model.layers.5.mlp.down_proj.zeroPoint shape: [3072] +384: model.layers.5.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +385: model.layers.5.mlp.gate_up_proj.scale shape: [16384] +386: model.layers.5.mlp.gate_up_proj.zeroPoint shape: [16384] +387: model.layers.5.post_attention_layernorm.weight shape: [3072] +388: model.layers.5.self_attn.o_proj.8bit_weight shape: [3072, 3072] +389: model.layers.5.self_attn.o_proj.scale shape: [3072] +390: model.layers.5.self_attn.o_proj.zeroPoint shape: [3072] +391: model.layers.5.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +392: model.layers.5.self_attn.qkv_proj.scale shape: [9216] +393: model.layers.5.self_attn.qkv_proj.zeroPoint shape: [9216] +394: model.layers.6.input_layernorm.weight shape: [3072] +395: model.layers.6.mlp.down_proj.8bit_weight shape: [3072, 8192] +396: model.layers.6.mlp.down_proj.scale shape: [3072] +397: model.layers.6.mlp.down_proj.zeroPoint shape: [3072] +398: model.layers.6.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +399: model.layers.6.mlp.gate_up_proj.scale shape: [16384] +400: model.layers.6.mlp.gate_up_proj.zeroPoint shape: [16384] +401: model.layers.6.post_attention_layernorm.weight shape: [3072] +402: model.layers.6.self_attn.o_proj.8bit_weight shape: [3072, 3072] +403: model.layers.6.self_attn.o_proj.scale shape: [3072] +404: model.layers.6.self_attn.o_proj.zeroPoint shape: [3072] +405: model.layers.6.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +406: model.layers.6.self_attn.qkv_proj.scale shape: [9216] +407: model.layers.6.self_attn.qkv_proj.zeroPoint shape: [9216] +408: model.layers.7.input_layernorm.weight shape: [3072] +409: model.layers.7.mlp.down_proj.8bit_weight shape: [3072, 8192] +410: model.layers.7.mlp.down_proj.scale shape: [3072] +411: model.layers.7.mlp.down_proj.zeroPoint shape: [3072] +412: model.layers.7.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +413: model.layers.7.mlp.gate_up_proj.scale shape: [16384] +414: model.layers.7.mlp.gate_up_proj.zeroPoint shape: [16384] +415: model.layers.7.post_attention_layernorm.weight shape: [3072] +416: model.layers.7.self_attn.o_proj.8bit_weight shape: [3072, 3072] +417: model.layers.7.self_attn.o_proj.scale shape: [3072] +418: model.layers.7.self_attn.o_proj.zeroPoint shape: [3072] +419: model.layers.7.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +420: model.layers.7.self_attn.qkv_proj.scale shape: [9216] +421: model.layers.7.self_attn.qkv_proj.zeroPoint shape: [9216] +422: model.layers.8.input_layernorm.weight shape: [3072] +423: model.layers.8.mlp.down_proj.8bit_weight shape: [3072, 8192] +424: model.layers.8.mlp.down_proj.scale shape: [3072] +425: model.layers.8.mlp.down_proj.zeroPoint shape: [3072] +426: model.layers.8.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +427: model.layers.8.mlp.gate_up_proj.scale shape: [16384] +428: model.layers.8.mlp.gate_up_proj.zeroPoint shape: [16384] +429: model.layers.8.post_attention_layernorm.weight shape: [3072] +430: model.layers.8.self_attn.o_proj.8bit_weight shape: [3072, 3072] +431: model.layers.8.self_attn.o_proj.scale shape: [3072] +432: model.layers.8.self_attn.o_proj.zeroPoint shape: [3072] +433: model.layers.8.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +434: model.layers.8.self_attn.qkv_proj.scale shape: [9216] +435: model.layers.8.self_attn.qkv_proj.zeroPoint shape: [9216] +436: model.layers.9.input_layernorm.weight shape: [3072] +437: model.layers.9.mlp.down_proj.8bit_weight shape: [3072, 8192] +438: model.layers.9.mlp.down_proj.scale shape: [3072] +439: model.layers.9.mlp.down_proj.zeroPoint shape: [3072] +440: model.layers.9.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] +441: model.layers.9.mlp.gate_up_proj.scale shape: [16384] +442: model.layers.9.mlp.gate_up_proj.zeroPoint shape: [16384] +443: model.layers.9.post_attention_layernorm.weight shape: [3072] +444: model.layers.9.self_attn.o_proj.8bit_weight shape: [3072, 3072] +445: model.layers.9.self_attn.o_proj.scale shape: [3072] +446: model.layers.9.self_attn.o_proj.zeroPoint shape: [3072] +447: model.layers.9.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] +448: model.layers.9.self_attn.qkv_proj.scale shape: [9216] +449: model.layers.9.self_attn.qkv_proj.zeroPoint shape: [9216] +450: model.norm.weight shape: [3072] diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs index c427114d95..543a415380 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs @@ -34,6 +34,34 @@ public void Phi3Mini4KShapeTest() Approvals.Verify(stateDictStr); } + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("Approvals")] + public void Phi3Mini4KInt8QuantizeShapeTest() + { + var model = new Phi3ForCasualLM(Phi3Config.Phi3Mini4kInstruct); + model.ToInt8QuantizeModule(); + var size = model.GetSizeInBytes(); + var stateDictStr = model.PeekShape(); + var sizeInGB = size / 1024 / 1024 / 1024; + sizeInGB.Should().Be(3); + Approvals.Verify(stateDictStr); + } + + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("Approvals")] + public void Phi3Mini4KInt4QuantizeShapeTest() + { + var model = new Phi3ForCasualLM(Phi3Config.Phi3Mini4kInstruct); + model.ToInt4QuantizeModule(); + var size = model.GetSizeInBytes(); + var stateDictStr = model.PeekShape(); + var sizeInGB = size / 1024 / 1024 / 1024; + sizeInGB.Should().Be(2); + Approvals.Verify(stateDictStr); + } + [Fact] [UseReporter(typeof(DiffReporter))] [UseApprovalSubdirectory("Approvals")] From 59e5da8850fd69a174661f58a9692d5eea188987 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 28 Jun 2024 13:27:05 -0700 Subject: [PATCH 12/41] use version string --- .../Microsoft.ML.GenAI.Samples.csproj | 2 +- .../Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs | 2 +- docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs | 6 +++++- eng/Versions.props | 3 +++ src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj | 6 +++--- .../Microsoft.ML.GenAI.Phi.Tests.csproj | 2 +- .../Microsoft.ML.Tokenizers.Tests.csproj | 2 -- 7 files changed, 14 insertions(+), 9 deletions(-) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj index e522cff52e..0331a32fc1 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj @@ -14,7 +14,7 @@ - + diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs index 17ce52cb10..30cdaeb16f 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs @@ -26,7 +26,7 @@ public static async Task RunAsync() torch.manual_seed(1); torch.set_default_dtype(defaultType); var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"; - var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device, quantizeToInt4: true); + var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device, quantizeToInt8: true); // agent var agent = new Phi3Agent(pipeline, "assistant") diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs index 4a4c108749..15c302d8d0 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs @@ -9,6 +9,7 @@ using static TorchSharp.torch; using TorchSharp; using Microsoft.ML.GenAI.Core.Extension; +using System.Text.Json; namespace Microsoft.ML.GenAI.Samples.Phi3Mini; @@ -49,8 +50,11 @@ public static CausalLMPipeline LoadPhi3Mini4KFro devices: ["cuda:0", "cpu", "disk"], deviceSizeMapInByte: deviceSizeMap); - model = model.ToDynamicLoadingModel(deviceMap, "cuda:0"); + var deviceMapJson = JsonSerializer.Serialize(deviceMap, new JsonSerializerOptions { WriteIndented = true }); + Console.WriteLine($"Device map:"); + Console.WriteLine(deviceMapJson); + model = model.ToDynamicLoadingModel(deviceMap, "cuda:0"); var pipeline = new CausalLMPipeline(tokenizer, model, device); timer.Stop(); Console.WriteLine($"Phi3 loaded in {timer.ElapsedMilliseconds / 1000} s"); diff --git a/eng/Versions.props b/eng/Versions.props index b48e6485bd..0bcb70dab5 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -65,6 +65,9 @@ 2.3.1 0.101.5 2.1.0.1 + 1.4.1 + 0.0.15 + 1.15.0 1.12.4 3.1.2 diff --git a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj index f4cdc6333a..9151745902 100644 --- a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj +++ b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj @@ -7,14 +7,14 @@ - - + + - + diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj index 0fbe73d1d0..fc70cdd257 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj @@ -16,7 +16,7 @@ - + diff --git a/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj b/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj index 7fb56e82aa..802cae464a 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj +++ b/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj @@ -48,6 +48,4 @@ - - \ No newline at end of file From f9539b8df4cf5550d73da367be5fcf002486e9b8 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 28 Jun 2024 13:31:31 -0700 Subject: [PATCH 13/41] remove special token from CreatePhi2 API --- src/Microsoft.ML.Tokenizers/Tokenizer.cs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs index 3c01f6ec1e..685a6f8671 100644 --- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs @@ -531,14 +531,12 @@ public static CodeGen CreatePhi2( string folder, string vocabFile = "vocab.json", string mergesFile = "merges.txt", - string specialTokensFile = "special_tokens_map.json", bool addPrefixSpace = false, bool addBeginOfSentence = false, bool addEndOfSentence = false) { var vocabPath = Path.Combine(folder, vocabFile); var mergesPath = Path.Combine(folder, mergesFile); - var specialTokenMapPath = Path.Combine(folder, specialTokensFile); using var vocabStream = File.OpenRead(vocabPath); using var mergesStream = File.OpenRead(mergesPath); From dbe818713382fc34b805b46d062eb0959b55bcf8 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 28 Jun 2024 15:31:46 -0700 Subject: [PATCH 14/41] set up quantize sample --- .../Phi3Mini/AutoGenSample.cs | 2 +- .../Phi3Mini/Utils.cs | 30 ++++++++++++++----- .../Pipeline/CausalLMPipeline.cs | 2 +- .../Phi3/Phi3ForCasualLM.cs | 8 +++-- 4 files changed, 31 insertions(+), 11 deletions(-) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs index 30cdaeb16f..392aec674d 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs @@ -26,7 +26,7 @@ public static async Task RunAsync() torch.manual_seed(1); torch.set_default_dtype(defaultType); var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"; - var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device, quantizeToInt8: true); + var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device, quantizeToInt8: false); // agent var agent = new Phi3Agent(pipeline, "assistant") diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs index 15c302d8d0..959af7d066 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs @@ -17,6 +17,7 @@ internal static class Utils { public static CausalLMPipeline LoadPhi3Mini4KFromFolder( string weightFolder, + string configName = "config.json", string device = "cuda", int modelSizeOnCudaInGB = 16, int modelSizeOnMemoryInGB = 64, @@ -24,10 +25,12 @@ public static CausalLMPipeline LoadPhi3Mini4KFro bool quantizeToInt8 = false, bool quantizeToInt4 = false) { - var defaultType = ScalarType.Float16; Console.WriteLine("Loading Phi3 from huggingface model weight folder"); + torch.set_default_device("meta"); + var configPath = System.IO.Path.Combine(weightFolder, configName); + var config = JsonSerializer.Deserialize(System.IO.File.ReadAllText(configPath)) ?? throw new ArgumentNullException(nameof(configPath)); var timer = System.Diagnostics.Stopwatch.StartNew(); - var model = Phi3ForCasualLM.FromPretrained(weightFolder, device: device, torchDtype: defaultType, checkPointName: "model.safetensors.index.json"); + var model = new Phi3ForCasualLM(config); var tokenizer = Phi3Tokenizer.FromPretrained(weightFolder); if (quantizeToInt8) @@ -41,23 +44,36 @@ public static CausalLMPipeline LoadPhi3Mini4KFro var deviceSizeMap = new Dictionary { - ["cuda:0"] = modelSizeOnCudaInGB * 1024 * 1024 * 1024, - ["cpu"] = modelSizeOnMemoryInGB * 1024 * 1024 * 1024, - ["disk"] = modelSizeOnDiskInGB * 1024 * 1024 * 1024, + ["cuda"] = modelSizeOnCudaInGB * 1L * 1024 * 1024 * 1024, + ["cpu"] = modelSizeOnMemoryInGB * 1L * 1024 * 1024 * 1024, + ["disk"] = modelSizeOnDiskInGB * 1L * 1024 * 1024 * 1024, }; var deviceMap = model.InferDeviceMapForEachLayer( - devices: ["cuda:0", "cpu", "disk"], + devices: ["cuda", "cpu", "disk"], deviceSizeMapInByte: deviceSizeMap); var deviceMapJson = JsonSerializer.Serialize(deviceMap, new JsonSerializerOptions { WriteIndented = true }); Console.WriteLine($"Device map:"); Console.WriteLine(deviceMapJson); - model = model.ToDynamicLoadingModel(deviceMap, "cuda:0"); + // load weight + torch.set_default_device("cpu"); + model = new Phi3ForCasualLM(config); + model.LoadSafeTensors(weightFolder); + if (quantizeToInt8) + { + model.ToInt8QuantizeModule(); + } + else if (quantizeToInt4) + { + model.ToInt4QuantizeModule(); + } + model = model.ToDynamicLoadingModel(deviceMap, "cuda"); var pipeline = new CausalLMPipeline(tokenizer, model, device); timer.Stop(); Console.WriteLine($"Phi3 loaded in {timer.ElapsedMilliseconds / 1000} s"); + torch.set_default_device(device); return pipeline; } diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs index 7122878a9b..57dc7f88c7 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs @@ -237,7 +237,7 @@ public virtual IEnumerable GenerateStreaming( using var newScope = NewDisposeScope(); var inputIds = this.Tokenizer.EncodeToIds(prompt); var inputTensor = torch.tensor(inputIds.ToArray(), dtype: ScalarType.Int64, device: this.Device).unsqueeze(0); - var attentionMask = torch.ones_like(inputTensor); + var attentionMask = torch.ones_like(inputTensor, device: this.Device); // set up stop token ids // stop token ids: [[eosId], [stopSequence1], [stopSequence2], ...] // when causal language model generates tokens, it will stop when it generates any token in stopSequences diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs index 8ab7ecc652..41b2d970fd 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs @@ -59,11 +59,15 @@ public static Phi3ForCasualLM FromPretrained( var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); modelConfig.DType = torchDtype; var phi = new Phi3ForCasualLM(modelConfig); - var loadedParameters = new Dictionary(); - phi.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, loadedParameters: loadedParameters); + phi.LoadSafeTensors(modelFolder, checkPointName); phi = phi.to(device); phi.eval(); return phi; } + + public void LoadSafeTensors(string modelFolder, string checkPointName = "model.safetensors.index.json") + { + this.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, useTqdm: false); + } } From ccaddfeb2276a34ccdc505da4e7f1ee4f56bbf2c Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 28 Jun 2024 16:27:29 -0700 Subject: [PATCH 15/41] initialize linear with zeros --- .../Phi3Mini/AutoGenSample.cs | 2 +- .../Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs | 14 ++++++++++++++ src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs | 6 ++---- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs index 392aec674d..5fafba2e11 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs @@ -26,7 +26,7 @@ public static async Task RunAsync() torch.manual_seed(1); torch.set_default_dtype(defaultType); var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"; - var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device, quantizeToInt8: false); + var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device, quantizeToInt8: true); // agent var agent = new Phi3Agent(pipeline, "assistant") diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs index 959af7d066..c75cdb9e92 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs @@ -59,8 +59,20 @@ public static CausalLMPipeline LoadPhi3Mini4KFro // load weight torch.set_default_device("cpu"); + + Console.WriteLine("Start loading"); + timer = System.Diagnostics.Stopwatch.StartNew(); model = new Phi3ForCasualLM(config); + timer.Stop(); + Console.WriteLine($"Phi3 model created in {timer.ElapsedMilliseconds / 1000} s"); + + timer = System.Diagnostics.Stopwatch.StartNew(); model.LoadSafeTensors(weightFolder); + timer.Stop(); + Console.WriteLine($"Phi3 weight loaded in {timer.ElapsedMilliseconds / 1000} s"); + + timer = System.Diagnostics.Stopwatch.StartNew(); + Console.WriteLine("Start quantizing if needed"); if (quantizeToInt8) { model.ToInt8QuantizeModule(); @@ -69,6 +81,8 @@ public static CausalLMPipeline LoadPhi3Mini4KFro { model.ToInt4QuantizeModule(); } + Console.WriteLine("Quantizing done"); + model = model.ToDynamicLoadingModel(deviceMap, "cuda"); var pipeline = new CausalLMPipeline(tokenizer, model, device); timer.Stop(); diff --git a/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs b/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs index c59fd9f38d..3c85a5a6f9 100644 --- a/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs +++ b/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs @@ -21,14 +21,12 @@ public GenAILinear(int inFeatures, int outFeatures, bool hasBias = true, ScalarT this._inFeatures = inFeatures; this._outFeatures = outFeatures; device ??= torch.get_default_device().ToString(); - this.weight = torch.randn(outFeatures, inFeatures, dtype: dtype, device: device); + this.weight = torch.zeros(outFeatures, inFeatures, dtype: dtype, device: device); if (hasBias) { - this.bias = torch.randn(outFeatures, dtype: dtype, device: device); + this.bias = torch.zeros(outFeatures, dtype: dtype, device: device); } - - this.RegisterComponents(); } #pragma warning disable MSML_GeneralName // This name should be PascalCased From 151fa864466b33838b03cafe4d379be437b36888 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Mon, 1 Jul 2024 12:50:51 -0700 Subject: [PATCH 16/41] update sample --- .../Phi3Mini/Utils.cs | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs index c75cdb9e92..f8147fa822 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs @@ -71,22 +71,29 @@ public static CausalLMPipeline LoadPhi3Mini4KFro timer.Stop(); Console.WriteLine($"Phi3 weight loaded in {timer.ElapsedMilliseconds / 1000} s"); - timer = System.Diagnostics.Stopwatch.StartNew(); - Console.WriteLine("Start quantizing if needed"); - if (quantizeToInt8) - { - model.ToInt8QuantizeModule(); - } - else if (quantizeToInt4) + if (quantizeToInt8 || quantizeToInt4) { - model.ToInt4QuantizeModule(); + timer = System.Diagnostics.Stopwatch.StartNew(); + Console.WriteLine("Start quantizing if needed"); + if (quantizeToInt8) + { + model.ToInt8QuantizeModule(); + } + else if (quantizeToInt4) + { + model.ToInt4QuantizeModule(); + } + Console.WriteLine("Quantizing done"); + timer.Stop(); + Console.WriteLine($"Quantizing done in {timer.ElapsedMilliseconds / 1000} s"); } - Console.WriteLine("Quantizing done"); + timer = System.Diagnostics.Stopwatch.StartNew(); + Console.WriteLine($"Start loading to device: {device}"); model = model.ToDynamicLoadingModel(deviceMap, "cuda"); - var pipeline = new CausalLMPipeline(tokenizer, model, device); timer.Stop(); - Console.WriteLine($"Phi3 loaded in {timer.ElapsedMilliseconds / 1000} s"); + Console.WriteLine($"Phi3 loaded to device: {device} in {timer.ElapsedMilliseconds / 1000} s"); + var pipeline = new CausalLMPipeline(tokenizer, model, device); torch.set_default_device(device); return pipeline; From b4e5d844412ec458aaddd586cc70935089f8894f Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Mon, 1 Jul 2024 12:59:21 -0700 Subject: [PATCH 17/41] add 6.0 to targetframework --- .../Phi3Mini/AutoGenSample.cs | 2 +- .../Phi3Mini/SemanticKernelSample.cs | 20 ++++++++----------- .../Microsoft.ML.GenAI.Samples/Program.cs | 2 +- .../Microsoft.ML.GenAI.Core.csproj | 2 +- .../Microsoft.ML.GenAI.Phi.csproj | 2 +- .../Microsoft.ML.GenAI.Core.Tests.csproj | 2 +- .../Microsoft.ML.GenAI.Phi.Tests.csproj | 2 +- 7 files changed, 14 insertions(+), 18 deletions(-) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs index 5fafba2e11..392aec674d 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs @@ -26,7 +26,7 @@ public static async Task RunAsync() torch.manual_seed(1); torch.set_default_dtype(defaultType); var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"; - var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device, quantizeToInt8: true); + var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device, quantizeToInt8: false); // agent var agent = new Phi3Agent(pipeline, "assistant") diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs index 0a3016dcde..b2ea3d610e 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs @@ -1,14 +1,8 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -using Microsoft.ML.GenAI.Phi; -using static TorchSharp.torch; -using TorchSharp; +using Microsoft.ML.GenAI.Phi.Extension; using Microsoft.SemanticKernel; -using Microsoft.ML.GenAI.Phi.Extension; using Microsoft.SemanticKernel.ChatCompletion; +using TorchSharp; +using static TorchSharp.torch; namespace Microsoft.ML.GenAI.Samples.Phi3Mini; @@ -26,7 +20,7 @@ public static async Task RunChatCompletionSample() torch.manual_seed(1); torch.set_default_dtype(defaultType); var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"; - var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device); + var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device); var kernel = Kernel.CreateBuilder() @@ -37,8 +31,10 @@ public static async Task RunChatCompletionSample() chatHistory.AddSystemMessage("you are a helpful assistant"); chatHistory.AddUserMessage("write a C# program to calculate the factorial of a number"); - var response = await chatService.GetChatMessageContentAsync(chatHistory); - Console.WriteLine(response); + await foreach (var response in chatService.GetStreamingChatMessageContentsAsync(chatHistory)) + { + Console.WriteLine(response); + } } public static async Task RunTextGenerationSample() diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs index 5e4355e595..1560bad306 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs @@ -1,4 +1,4 @@ // See https://aka.ms/new-console-template for more information using Microsoft.ML.GenAI.Samples.Phi3Mini; -await AutoGenSample.RunAsync(); +await SemanticKernelSample.RunChatCompletionSample(); diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj index 9f358d9914..c2a8c1319c 100644 --- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj +++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj @@ -1,7 +1,7 @@  - net8.0 + net6.0;net8.0 false enable preview diff --git a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj index 9151745902..0c51883e6c 100644 --- a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj +++ b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj @@ -1,7 +1,7 @@  - net8.0 + net6.0;net8.0 enable enable diff --git a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj index 7e3d9a9943..bc0eea0bb6 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj @@ -1,7 +1,7 @@  - net8.0 + net6.0 enable enable diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj index fc70cdd257..bf4580ee9f 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj @@ -1,7 +1,7 @@  - net8.0 + net6.0 enable enable From b9604a0ced1b3befc4e71f226a30dd569e61cb47 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Mon, 1 Jul 2024 13:31:25 -0700 Subject: [PATCH 18/41] fix tests --- .../Phi3Mini/SemanticKernelSample.cs | 2 +- .../Module/GenAILinear.cs | 2 + ...i3Mini4KInt8QuantizeShapeTest.received.txt | 451 ------------------ 3 files changed, 3 insertions(+), 452 deletions(-) delete mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KInt8QuantizeShapeTest.received.txt diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs index b2ea3d610e..8ce0ead983 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs @@ -33,7 +33,7 @@ public static async Task RunChatCompletionSample() await foreach (var response in chatService.GetStreamingChatMessageContentsAsync(chatHistory)) { - Console.WriteLine(response); + Console.Write(response); } } diff --git a/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs b/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs index 3c85a5a6f9..77bcadeb82 100644 --- a/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs +++ b/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs @@ -27,6 +27,8 @@ public GenAILinear(int inFeatures, int outFeatures, bool hasBias = true, ScalarT { this.bias = torch.zeros(outFeatures, dtype: dtype, device: device); } + + base.RegisterComponents(); } #pragma warning disable MSML_GeneralName // This name should be PascalCased diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KInt8QuantizeShapeTest.received.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KInt8QuantizeShapeTest.received.txt deleted file mode 100644 index d9c51c94a9..0000000000 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.Phi3Mini4KInt8QuantizeShapeTest.received.txt +++ /dev/null @@ -1,451 +0,0 @@ -0: lm_head.weight shape: [32064, 3072] -1: model.embed_tokens.weight shape: [32064, 3072] -2: model.layers.0.input_layernorm.weight shape: [3072] -3: model.layers.0.mlp.down_proj.8bit_weight shape: [3072, 8192] -4: model.layers.0.mlp.down_proj.scale shape: [3072] -5: model.layers.0.mlp.down_proj.zeroPoint shape: [3072] -6: model.layers.0.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -7: model.layers.0.mlp.gate_up_proj.scale shape: [16384] -8: model.layers.0.mlp.gate_up_proj.zeroPoint shape: [16384] -9: model.layers.0.post_attention_layernorm.weight shape: [3072] -10: model.layers.0.self_attn.o_proj.8bit_weight shape: [3072, 3072] -11: model.layers.0.self_attn.o_proj.scale shape: [3072] -12: model.layers.0.self_attn.o_proj.zeroPoint shape: [3072] -13: model.layers.0.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -14: model.layers.0.self_attn.qkv_proj.scale shape: [9216] -15: model.layers.0.self_attn.qkv_proj.zeroPoint shape: [9216] -16: model.layers.1.input_layernorm.weight shape: [3072] -17: model.layers.1.mlp.down_proj.8bit_weight shape: [3072, 8192] -18: model.layers.1.mlp.down_proj.scale shape: [3072] -19: model.layers.1.mlp.down_proj.zeroPoint shape: [3072] -20: model.layers.1.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -21: model.layers.1.mlp.gate_up_proj.scale shape: [16384] -22: model.layers.1.mlp.gate_up_proj.zeroPoint shape: [16384] -23: model.layers.1.post_attention_layernorm.weight shape: [3072] -24: model.layers.1.self_attn.o_proj.8bit_weight shape: [3072, 3072] -25: model.layers.1.self_attn.o_proj.scale shape: [3072] -26: model.layers.1.self_attn.o_proj.zeroPoint shape: [3072] -27: model.layers.1.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -28: model.layers.1.self_attn.qkv_proj.scale shape: [9216] -29: model.layers.1.self_attn.qkv_proj.zeroPoint shape: [9216] -30: model.layers.10.input_layernorm.weight shape: [3072] -31: model.layers.10.mlp.down_proj.8bit_weight shape: [3072, 8192] -32: model.layers.10.mlp.down_proj.scale shape: [3072] -33: model.layers.10.mlp.down_proj.zeroPoint shape: [3072] -34: model.layers.10.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -35: model.layers.10.mlp.gate_up_proj.scale shape: [16384] -36: model.layers.10.mlp.gate_up_proj.zeroPoint shape: [16384] -37: model.layers.10.post_attention_layernorm.weight shape: [3072] -38: model.layers.10.self_attn.o_proj.8bit_weight shape: [3072, 3072] -39: model.layers.10.self_attn.o_proj.scale shape: [3072] -40: model.layers.10.self_attn.o_proj.zeroPoint shape: [3072] -41: model.layers.10.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -42: model.layers.10.self_attn.qkv_proj.scale shape: [9216] -43: model.layers.10.self_attn.qkv_proj.zeroPoint shape: [9216] -44: model.layers.11.input_layernorm.weight shape: [3072] -45: model.layers.11.mlp.down_proj.8bit_weight shape: [3072, 8192] -46: model.layers.11.mlp.down_proj.scale shape: [3072] -47: model.layers.11.mlp.down_proj.zeroPoint shape: [3072] -48: model.layers.11.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -49: model.layers.11.mlp.gate_up_proj.scale shape: [16384] -50: model.layers.11.mlp.gate_up_proj.zeroPoint shape: [16384] -51: model.layers.11.post_attention_layernorm.weight shape: [3072] -52: model.layers.11.self_attn.o_proj.8bit_weight shape: [3072, 3072] -53: model.layers.11.self_attn.o_proj.scale shape: [3072] -54: model.layers.11.self_attn.o_proj.zeroPoint shape: [3072] -55: model.layers.11.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -56: model.layers.11.self_attn.qkv_proj.scale shape: [9216] -57: model.layers.11.self_attn.qkv_proj.zeroPoint shape: [9216] -58: model.layers.12.input_layernorm.weight shape: [3072] -59: model.layers.12.mlp.down_proj.8bit_weight shape: [3072, 8192] -60: model.layers.12.mlp.down_proj.scale shape: [3072] -61: model.layers.12.mlp.down_proj.zeroPoint shape: [3072] -62: model.layers.12.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -63: model.layers.12.mlp.gate_up_proj.scale shape: [16384] -64: model.layers.12.mlp.gate_up_proj.zeroPoint shape: [16384] -65: model.layers.12.post_attention_layernorm.weight shape: [3072] -66: model.layers.12.self_attn.o_proj.8bit_weight shape: [3072, 3072] -67: model.layers.12.self_attn.o_proj.scale shape: [3072] -68: model.layers.12.self_attn.o_proj.zeroPoint shape: [3072] -69: model.layers.12.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -70: model.layers.12.self_attn.qkv_proj.scale shape: [9216] -71: model.layers.12.self_attn.qkv_proj.zeroPoint shape: [9216] -72: model.layers.13.input_layernorm.weight shape: [3072] -73: model.layers.13.mlp.down_proj.8bit_weight shape: [3072, 8192] -74: model.layers.13.mlp.down_proj.scale shape: [3072] -75: model.layers.13.mlp.down_proj.zeroPoint shape: [3072] -76: model.layers.13.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -77: model.layers.13.mlp.gate_up_proj.scale shape: [16384] -78: model.layers.13.mlp.gate_up_proj.zeroPoint shape: [16384] -79: model.layers.13.post_attention_layernorm.weight shape: [3072] -80: model.layers.13.self_attn.o_proj.8bit_weight shape: [3072, 3072] -81: model.layers.13.self_attn.o_proj.scale shape: [3072] -82: model.layers.13.self_attn.o_proj.zeroPoint shape: [3072] -83: model.layers.13.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -84: model.layers.13.self_attn.qkv_proj.scale shape: [9216] -85: model.layers.13.self_attn.qkv_proj.zeroPoint shape: [9216] -86: model.layers.14.input_layernorm.weight shape: [3072] -87: model.layers.14.mlp.down_proj.8bit_weight shape: [3072, 8192] -88: model.layers.14.mlp.down_proj.scale shape: [3072] -89: model.layers.14.mlp.down_proj.zeroPoint shape: [3072] -90: model.layers.14.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -91: model.layers.14.mlp.gate_up_proj.scale shape: [16384] -92: model.layers.14.mlp.gate_up_proj.zeroPoint shape: [16384] -93: model.layers.14.post_attention_layernorm.weight shape: [3072] -94: model.layers.14.self_attn.o_proj.8bit_weight shape: [3072, 3072] -95: model.layers.14.self_attn.o_proj.scale shape: [3072] -96: model.layers.14.self_attn.o_proj.zeroPoint shape: [3072] -97: model.layers.14.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -98: model.layers.14.self_attn.qkv_proj.scale shape: [9216] -99: model.layers.14.self_attn.qkv_proj.zeroPoint shape: [9216] -100: model.layers.15.input_layernorm.weight shape: [3072] -101: model.layers.15.mlp.down_proj.8bit_weight shape: [3072, 8192] -102: model.layers.15.mlp.down_proj.scale shape: [3072] -103: model.layers.15.mlp.down_proj.zeroPoint shape: [3072] -104: model.layers.15.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -105: model.layers.15.mlp.gate_up_proj.scale shape: [16384] -106: model.layers.15.mlp.gate_up_proj.zeroPoint shape: [16384] -107: model.layers.15.post_attention_layernorm.weight shape: [3072] -108: model.layers.15.self_attn.o_proj.8bit_weight shape: [3072, 3072] -109: model.layers.15.self_attn.o_proj.scale shape: [3072] -110: model.layers.15.self_attn.o_proj.zeroPoint shape: [3072] -111: model.layers.15.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -112: model.layers.15.self_attn.qkv_proj.scale shape: [9216] -113: model.layers.15.self_attn.qkv_proj.zeroPoint shape: [9216] -114: model.layers.16.input_layernorm.weight shape: [3072] -115: model.layers.16.mlp.down_proj.8bit_weight shape: [3072, 8192] -116: model.layers.16.mlp.down_proj.scale shape: [3072] -117: model.layers.16.mlp.down_proj.zeroPoint shape: [3072] -118: model.layers.16.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -119: model.layers.16.mlp.gate_up_proj.scale shape: [16384] -120: model.layers.16.mlp.gate_up_proj.zeroPoint shape: [16384] -121: model.layers.16.post_attention_layernorm.weight shape: [3072] -122: model.layers.16.self_attn.o_proj.8bit_weight shape: [3072, 3072] -123: model.layers.16.self_attn.o_proj.scale shape: [3072] -124: model.layers.16.self_attn.o_proj.zeroPoint shape: [3072] -125: model.layers.16.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -126: model.layers.16.self_attn.qkv_proj.scale shape: [9216] -127: model.layers.16.self_attn.qkv_proj.zeroPoint shape: [9216] -128: model.layers.17.input_layernorm.weight shape: [3072] -129: model.layers.17.mlp.down_proj.8bit_weight shape: [3072, 8192] -130: model.layers.17.mlp.down_proj.scale shape: [3072] -131: model.layers.17.mlp.down_proj.zeroPoint shape: [3072] -132: model.layers.17.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -133: model.layers.17.mlp.gate_up_proj.scale shape: [16384] -134: model.layers.17.mlp.gate_up_proj.zeroPoint shape: [16384] -135: model.layers.17.post_attention_layernorm.weight shape: [3072] -136: model.layers.17.self_attn.o_proj.8bit_weight shape: [3072, 3072] -137: model.layers.17.self_attn.o_proj.scale shape: [3072] -138: model.layers.17.self_attn.o_proj.zeroPoint shape: [3072] -139: model.layers.17.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -140: model.layers.17.self_attn.qkv_proj.scale shape: [9216] -141: model.layers.17.self_attn.qkv_proj.zeroPoint shape: [9216] -142: model.layers.18.input_layernorm.weight shape: [3072] -143: model.layers.18.mlp.down_proj.8bit_weight shape: [3072, 8192] -144: model.layers.18.mlp.down_proj.scale shape: [3072] -145: model.layers.18.mlp.down_proj.zeroPoint shape: [3072] -146: model.layers.18.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -147: model.layers.18.mlp.gate_up_proj.scale shape: [16384] -148: model.layers.18.mlp.gate_up_proj.zeroPoint shape: [16384] -149: model.layers.18.post_attention_layernorm.weight shape: [3072] -150: model.layers.18.self_attn.o_proj.8bit_weight shape: [3072, 3072] -151: model.layers.18.self_attn.o_proj.scale shape: [3072] -152: model.layers.18.self_attn.o_proj.zeroPoint shape: [3072] -153: model.layers.18.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -154: model.layers.18.self_attn.qkv_proj.scale shape: [9216] -155: model.layers.18.self_attn.qkv_proj.zeroPoint shape: [9216] -156: model.layers.19.input_layernorm.weight shape: [3072] -157: model.layers.19.mlp.down_proj.8bit_weight shape: [3072, 8192] -158: model.layers.19.mlp.down_proj.scale shape: [3072] -159: model.layers.19.mlp.down_proj.zeroPoint shape: [3072] -160: model.layers.19.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -161: model.layers.19.mlp.gate_up_proj.scale shape: [16384] -162: model.layers.19.mlp.gate_up_proj.zeroPoint shape: [16384] -163: model.layers.19.post_attention_layernorm.weight shape: [3072] -164: model.layers.19.self_attn.o_proj.8bit_weight shape: [3072, 3072] -165: model.layers.19.self_attn.o_proj.scale shape: [3072] -166: model.layers.19.self_attn.o_proj.zeroPoint shape: [3072] -167: model.layers.19.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -168: model.layers.19.self_attn.qkv_proj.scale shape: [9216] -169: model.layers.19.self_attn.qkv_proj.zeroPoint shape: [9216] -170: model.layers.2.input_layernorm.weight shape: [3072] -171: model.layers.2.mlp.down_proj.8bit_weight shape: [3072, 8192] -172: model.layers.2.mlp.down_proj.scale shape: [3072] -173: model.layers.2.mlp.down_proj.zeroPoint shape: [3072] -174: model.layers.2.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -175: model.layers.2.mlp.gate_up_proj.scale shape: [16384] -176: model.layers.2.mlp.gate_up_proj.zeroPoint shape: [16384] -177: model.layers.2.post_attention_layernorm.weight shape: [3072] -178: model.layers.2.self_attn.o_proj.8bit_weight shape: [3072, 3072] -179: model.layers.2.self_attn.o_proj.scale shape: [3072] -180: model.layers.2.self_attn.o_proj.zeroPoint shape: [3072] -181: model.layers.2.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -182: model.layers.2.self_attn.qkv_proj.scale shape: [9216] -183: model.layers.2.self_attn.qkv_proj.zeroPoint shape: [9216] -184: model.layers.20.input_layernorm.weight shape: [3072] -185: model.layers.20.mlp.down_proj.8bit_weight shape: [3072, 8192] -186: model.layers.20.mlp.down_proj.scale shape: [3072] -187: model.layers.20.mlp.down_proj.zeroPoint shape: [3072] -188: model.layers.20.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -189: model.layers.20.mlp.gate_up_proj.scale shape: [16384] -190: model.layers.20.mlp.gate_up_proj.zeroPoint shape: [16384] -191: model.layers.20.post_attention_layernorm.weight shape: [3072] -192: model.layers.20.self_attn.o_proj.8bit_weight shape: [3072, 3072] -193: model.layers.20.self_attn.o_proj.scale shape: [3072] -194: model.layers.20.self_attn.o_proj.zeroPoint shape: [3072] -195: model.layers.20.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -196: model.layers.20.self_attn.qkv_proj.scale shape: [9216] -197: model.layers.20.self_attn.qkv_proj.zeroPoint shape: [9216] -198: model.layers.21.input_layernorm.weight shape: [3072] -199: model.layers.21.mlp.down_proj.8bit_weight shape: [3072, 8192] -200: model.layers.21.mlp.down_proj.scale shape: [3072] -201: model.layers.21.mlp.down_proj.zeroPoint shape: [3072] -202: model.layers.21.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -203: model.layers.21.mlp.gate_up_proj.scale shape: [16384] -204: model.layers.21.mlp.gate_up_proj.zeroPoint shape: [16384] -205: model.layers.21.post_attention_layernorm.weight shape: [3072] -206: model.layers.21.self_attn.o_proj.8bit_weight shape: [3072, 3072] -207: model.layers.21.self_attn.o_proj.scale shape: [3072] -208: model.layers.21.self_attn.o_proj.zeroPoint shape: [3072] -209: model.layers.21.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -210: model.layers.21.self_attn.qkv_proj.scale shape: [9216] -211: model.layers.21.self_attn.qkv_proj.zeroPoint shape: [9216] -212: model.layers.22.input_layernorm.weight shape: [3072] -213: model.layers.22.mlp.down_proj.8bit_weight shape: [3072, 8192] -214: model.layers.22.mlp.down_proj.scale shape: [3072] -215: model.layers.22.mlp.down_proj.zeroPoint shape: [3072] -216: model.layers.22.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -217: model.layers.22.mlp.gate_up_proj.scale shape: [16384] -218: model.layers.22.mlp.gate_up_proj.zeroPoint shape: [16384] -219: model.layers.22.post_attention_layernorm.weight shape: [3072] -220: model.layers.22.self_attn.o_proj.8bit_weight shape: [3072, 3072] -221: model.layers.22.self_attn.o_proj.scale shape: [3072] -222: model.layers.22.self_attn.o_proj.zeroPoint shape: [3072] -223: model.layers.22.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -224: model.layers.22.self_attn.qkv_proj.scale shape: [9216] -225: model.layers.22.self_attn.qkv_proj.zeroPoint shape: [9216] -226: model.layers.23.input_layernorm.weight shape: [3072] -227: model.layers.23.mlp.down_proj.8bit_weight shape: [3072, 8192] -228: model.layers.23.mlp.down_proj.scale shape: [3072] -229: model.layers.23.mlp.down_proj.zeroPoint shape: [3072] -230: model.layers.23.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -231: model.layers.23.mlp.gate_up_proj.scale shape: [16384] -232: model.layers.23.mlp.gate_up_proj.zeroPoint shape: [16384] -233: model.layers.23.post_attention_layernorm.weight shape: [3072] -234: model.layers.23.self_attn.o_proj.8bit_weight shape: [3072, 3072] -235: model.layers.23.self_attn.o_proj.scale shape: [3072] -236: model.layers.23.self_attn.o_proj.zeroPoint shape: [3072] -237: model.layers.23.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -238: model.layers.23.self_attn.qkv_proj.scale shape: [9216] -239: model.layers.23.self_attn.qkv_proj.zeroPoint shape: [9216] -240: model.layers.24.input_layernorm.weight shape: [3072] -241: model.layers.24.mlp.down_proj.8bit_weight shape: [3072, 8192] -242: model.layers.24.mlp.down_proj.scale shape: [3072] -243: model.layers.24.mlp.down_proj.zeroPoint shape: [3072] -244: model.layers.24.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -245: model.layers.24.mlp.gate_up_proj.scale shape: [16384] -246: model.layers.24.mlp.gate_up_proj.zeroPoint shape: [16384] -247: model.layers.24.post_attention_layernorm.weight shape: [3072] -248: model.layers.24.self_attn.o_proj.8bit_weight shape: [3072, 3072] -249: model.layers.24.self_attn.o_proj.scale shape: [3072] -250: model.layers.24.self_attn.o_proj.zeroPoint shape: [3072] -251: model.layers.24.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -252: model.layers.24.self_attn.qkv_proj.scale shape: [9216] -253: model.layers.24.self_attn.qkv_proj.zeroPoint shape: [9216] -254: model.layers.25.input_layernorm.weight shape: [3072] -255: model.layers.25.mlp.down_proj.8bit_weight shape: [3072, 8192] -256: model.layers.25.mlp.down_proj.scale shape: [3072] -257: model.layers.25.mlp.down_proj.zeroPoint shape: [3072] -258: model.layers.25.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -259: model.layers.25.mlp.gate_up_proj.scale shape: [16384] -260: model.layers.25.mlp.gate_up_proj.zeroPoint shape: [16384] -261: model.layers.25.post_attention_layernorm.weight shape: [3072] -262: model.layers.25.self_attn.o_proj.8bit_weight shape: [3072, 3072] -263: model.layers.25.self_attn.o_proj.scale shape: [3072] -264: model.layers.25.self_attn.o_proj.zeroPoint shape: [3072] -265: model.layers.25.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -266: model.layers.25.self_attn.qkv_proj.scale shape: [9216] -267: model.layers.25.self_attn.qkv_proj.zeroPoint shape: [9216] -268: model.layers.26.input_layernorm.weight shape: [3072] -269: model.layers.26.mlp.down_proj.8bit_weight shape: [3072, 8192] -270: model.layers.26.mlp.down_proj.scale shape: [3072] -271: model.layers.26.mlp.down_proj.zeroPoint shape: [3072] -272: model.layers.26.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -273: model.layers.26.mlp.gate_up_proj.scale shape: [16384] -274: model.layers.26.mlp.gate_up_proj.zeroPoint shape: [16384] -275: model.layers.26.post_attention_layernorm.weight shape: [3072] -276: model.layers.26.self_attn.o_proj.8bit_weight shape: [3072, 3072] -277: model.layers.26.self_attn.o_proj.scale shape: [3072] -278: model.layers.26.self_attn.o_proj.zeroPoint shape: [3072] -279: model.layers.26.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -280: model.layers.26.self_attn.qkv_proj.scale shape: [9216] -281: model.layers.26.self_attn.qkv_proj.zeroPoint shape: [9216] -282: model.layers.27.input_layernorm.weight shape: [3072] -283: model.layers.27.mlp.down_proj.8bit_weight shape: [3072, 8192] -284: model.layers.27.mlp.down_proj.scale shape: [3072] -285: model.layers.27.mlp.down_proj.zeroPoint shape: [3072] -286: model.layers.27.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -287: model.layers.27.mlp.gate_up_proj.scale shape: [16384] -288: model.layers.27.mlp.gate_up_proj.zeroPoint shape: [16384] -289: model.layers.27.post_attention_layernorm.weight shape: [3072] -290: model.layers.27.self_attn.o_proj.8bit_weight shape: [3072, 3072] -291: model.layers.27.self_attn.o_proj.scale shape: [3072] -292: model.layers.27.self_attn.o_proj.zeroPoint shape: [3072] -293: model.layers.27.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -294: model.layers.27.self_attn.qkv_proj.scale shape: [9216] -295: model.layers.27.self_attn.qkv_proj.zeroPoint shape: [9216] -296: model.layers.28.input_layernorm.weight shape: [3072] -297: model.layers.28.mlp.down_proj.8bit_weight shape: [3072, 8192] -298: model.layers.28.mlp.down_proj.scale shape: [3072] -299: model.layers.28.mlp.down_proj.zeroPoint shape: [3072] -300: model.layers.28.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -301: model.layers.28.mlp.gate_up_proj.scale shape: [16384] -302: model.layers.28.mlp.gate_up_proj.zeroPoint shape: [16384] -303: model.layers.28.post_attention_layernorm.weight shape: [3072] -304: model.layers.28.self_attn.o_proj.8bit_weight shape: [3072, 3072] -305: model.layers.28.self_attn.o_proj.scale shape: [3072] -306: model.layers.28.self_attn.o_proj.zeroPoint shape: [3072] -307: model.layers.28.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -308: model.layers.28.self_attn.qkv_proj.scale shape: [9216] -309: model.layers.28.self_attn.qkv_proj.zeroPoint shape: [9216] -310: model.layers.29.input_layernorm.weight shape: [3072] -311: model.layers.29.mlp.down_proj.8bit_weight shape: [3072, 8192] -312: model.layers.29.mlp.down_proj.scale shape: [3072] -313: model.layers.29.mlp.down_proj.zeroPoint shape: [3072] -314: model.layers.29.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -315: model.layers.29.mlp.gate_up_proj.scale shape: [16384] -316: model.layers.29.mlp.gate_up_proj.zeroPoint shape: [16384] -317: model.layers.29.post_attention_layernorm.weight shape: [3072] -318: model.layers.29.self_attn.o_proj.8bit_weight shape: [3072, 3072] -319: model.layers.29.self_attn.o_proj.scale shape: [3072] -320: model.layers.29.self_attn.o_proj.zeroPoint shape: [3072] -321: model.layers.29.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -322: model.layers.29.self_attn.qkv_proj.scale shape: [9216] -323: model.layers.29.self_attn.qkv_proj.zeroPoint shape: [9216] -324: model.layers.3.input_layernorm.weight shape: [3072] -325: model.layers.3.mlp.down_proj.8bit_weight shape: [3072, 8192] -326: model.layers.3.mlp.down_proj.scale shape: [3072] -327: model.layers.3.mlp.down_proj.zeroPoint shape: [3072] -328: model.layers.3.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -329: model.layers.3.mlp.gate_up_proj.scale shape: [16384] -330: model.layers.3.mlp.gate_up_proj.zeroPoint shape: [16384] -331: model.layers.3.post_attention_layernorm.weight shape: [3072] -332: model.layers.3.self_attn.o_proj.8bit_weight shape: [3072, 3072] -333: model.layers.3.self_attn.o_proj.scale shape: [3072] -334: model.layers.3.self_attn.o_proj.zeroPoint shape: [3072] -335: model.layers.3.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -336: model.layers.3.self_attn.qkv_proj.scale shape: [9216] -337: model.layers.3.self_attn.qkv_proj.zeroPoint shape: [9216] -338: model.layers.30.input_layernorm.weight shape: [3072] -339: model.layers.30.mlp.down_proj.8bit_weight shape: [3072, 8192] -340: model.layers.30.mlp.down_proj.scale shape: [3072] -341: model.layers.30.mlp.down_proj.zeroPoint shape: [3072] -342: model.layers.30.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -343: model.layers.30.mlp.gate_up_proj.scale shape: [16384] -344: model.layers.30.mlp.gate_up_proj.zeroPoint shape: [16384] -345: model.layers.30.post_attention_layernorm.weight shape: [3072] -346: model.layers.30.self_attn.o_proj.8bit_weight shape: [3072, 3072] -347: model.layers.30.self_attn.o_proj.scale shape: [3072] -348: model.layers.30.self_attn.o_proj.zeroPoint shape: [3072] -349: model.layers.30.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -350: model.layers.30.self_attn.qkv_proj.scale shape: [9216] -351: model.layers.30.self_attn.qkv_proj.zeroPoint shape: [9216] -352: model.layers.31.input_layernorm.weight shape: [3072] -353: model.layers.31.mlp.down_proj.8bit_weight shape: [3072, 8192] -354: model.layers.31.mlp.down_proj.scale shape: [3072] -355: model.layers.31.mlp.down_proj.zeroPoint shape: [3072] -356: model.layers.31.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -357: model.layers.31.mlp.gate_up_proj.scale shape: [16384] -358: model.layers.31.mlp.gate_up_proj.zeroPoint shape: [16384] -359: model.layers.31.post_attention_layernorm.weight shape: [3072] -360: model.layers.31.self_attn.o_proj.8bit_weight shape: [3072, 3072] -361: model.layers.31.self_attn.o_proj.scale shape: [3072] -362: model.layers.31.self_attn.o_proj.zeroPoint shape: [3072] -363: model.layers.31.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -364: model.layers.31.self_attn.qkv_proj.scale shape: [9216] -365: model.layers.31.self_attn.qkv_proj.zeroPoint shape: [9216] -366: model.layers.4.input_layernorm.weight shape: [3072] -367: model.layers.4.mlp.down_proj.8bit_weight shape: [3072, 8192] -368: model.layers.4.mlp.down_proj.scale shape: [3072] -369: model.layers.4.mlp.down_proj.zeroPoint shape: [3072] -370: model.layers.4.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -371: model.layers.4.mlp.gate_up_proj.scale shape: [16384] -372: model.layers.4.mlp.gate_up_proj.zeroPoint shape: [16384] -373: model.layers.4.post_attention_layernorm.weight shape: [3072] -374: model.layers.4.self_attn.o_proj.8bit_weight shape: [3072, 3072] -375: model.layers.4.self_attn.o_proj.scale shape: [3072] -376: model.layers.4.self_attn.o_proj.zeroPoint shape: [3072] -377: model.layers.4.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -378: model.layers.4.self_attn.qkv_proj.scale shape: [9216] -379: model.layers.4.self_attn.qkv_proj.zeroPoint shape: [9216] -380: model.layers.5.input_layernorm.weight shape: [3072] -381: model.layers.5.mlp.down_proj.8bit_weight shape: [3072, 8192] -382: model.layers.5.mlp.down_proj.scale shape: [3072] -383: model.layers.5.mlp.down_proj.zeroPoint shape: [3072] -384: model.layers.5.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -385: model.layers.5.mlp.gate_up_proj.scale shape: [16384] -386: model.layers.5.mlp.gate_up_proj.zeroPoint shape: [16384] -387: model.layers.5.post_attention_layernorm.weight shape: [3072] -388: model.layers.5.self_attn.o_proj.8bit_weight shape: [3072, 3072] -389: model.layers.5.self_attn.o_proj.scale shape: [3072] -390: model.layers.5.self_attn.o_proj.zeroPoint shape: [3072] -391: model.layers.5.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -392: model.layers.5.self_attn.qkv_proj.scale shape: [9216] -393: model.layers.5.self_attn.qkv_proj.zeroPoint shape: [9216] -394: model.layers.6.input_layernorm.weight shape: [3072] -395: model.layers.6.mlp.down_proj.8bit_weight shape: [3072, 8192] -396: model.layers.6.mlp.down_proj.scale shape: [3072] -397: model.layers.6.mlp.down_proj.zeroPoint shape: [3072] -398: model.layers.6.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -399: model.layers.6.mlp.gate_up_proj.scale shape: [16384] -400: model.layers.6.mlp.gate_up_proj.zeroPoint shape: [16384] -401: model.layers.6.post_attention_layernorm.weight shape: [3072] -402: model.layers.6.self_attn.o_proj.8bit_weight shape: [3072, 3072] -403: model.layers.6.self_attn.o_proj.scale shape: [3072] -404: model.layers.6.self_attn.o_proj.zeroPoint shape: [3072] -405: model.layers.6.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -406: model.layers.6.self_attn.qkv_proj.scale shape: [9216] -407: model.layers.6.self_attn.qkv_proj.zeroPoint shape: [9216] -408: model.layers.7.input_layernorm.weight shape: [3072] -409: model.layers.7.mlp.down_proj.8bit_weight shape: [3072, 8192] -410: model.layers.7.mlp.down_proj.scale shape: [3072] -411: model.layers.7.mlp.down_proj.zeroPoint shape: [3072] -412: model.layers.7.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -413: model.layers.7.mlp.gate_up_proj.scale shape: [16384] -414: model.layers.7.mlp.gate_up_proj.zeroPoint shape: [16384] -415: model.layers.7.post_attention_layernorm.weight shape: [3072] -416: model.layers.7.self_attn.o_proj.8bit_weight shape: [3072, 3072] -417: model.layers.7.self_attn.o_proj.scale shape: [3072] -418: model.layers.7.self_attn.o_proj.zeroPoint shape: [3072] -419: model.layers.7.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -420: model.layers.7.self_attn.qkv_proj.scale shape: [9216] -421: model.layers.7.self_attn.qkv_proj.zeroPoint shape: [9216] -422: model.layers.8.input_layernorm.weight shape: [3072] -423: model.layers.8.mlp.down_proj.8bit_weight shape: [3072, 8192] -424: model.layers.8.mlp.down_proj.scale shape: [3072] -425: model.layers.8.mlp.down_proj.zeroPoint shape: [3072] -426: model.layers.8.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -427: model.layers.8.mlp.gate_up_proj.scale shape: [16384] -428: model.layers.8.mlp.gate_up_proj.zeroPoint shape: [16384] -429: model.layers.8.post_attention_layernorm.weight shape: [3072] -430: model.layers.8.self_attn.o_proj.8bit_weight shape: [3072, 3072] -431: model.layers.8.self_attn.o_proj.scale shape: [3072] -432: model.layers.8.self_attn.o_proj.zeroPoint shape: [3072] -433: model.layers.8.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -434: model.layers.8.self_attn.qkv_proj.scale shape: [9216] -435: model.layers.8.self_attn.qkv_proj.zeroPoint shape: [9216] -436: model.layers.9.input_layernorm.weight shape: [3072] -437: model.layers.9.mlp.down_proj.8bit_weight shape: [3072, 8192] -438: model.layers.9.mlp.down_proj.scale shape: [3072] -439: model.layers.9.mlp.down_proj.zeroPoint shape: [3072] -440: model.layers.9.mlp.gate_up_proj.8bit_weight shape: [16384, 3072] -441: model.layers.9.mlp.gate_up_proj.scale shape: [16384] -442: model.layers.9.mlp.gate_up_proj.zeroPoint shape: [16384] -443: model.layers.9.post_attention_layernorm.weight shape: [3072] -444: model.layers.9.self_attn.o_proj.8bit_weight shape: [3072, 3072] -445: model.layers.9.self_attn.o_proj.scale shape: [3072] -446: model.layers.9.self_attn.o_proj.zeroPoint shape: [3072] -447: model.layers.9.self_attn.qkv_proj.8bit_weight shape: [9216, 3072] -448: model.layers.9.self_attn.qkv_proj.scale shape: [9216] -449: model.layers.9.self_attn.qkv_proj.zeroPoint shape: [9216] -450: model.norm.weight shape: [3072] From 73c0d31cf363913a5f54b1e01da8837cf3518404 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Mon, 1 Jul 2024 13:40:26 -0700 Subject: [PATCH 19/41] update --- .../Microsoft.ML.GenAI.Core.Tests.csproj | 8 +------- .../QuantizedLinearTests.cs | 17 +++++------------ .../AutoGenTests.cs | 9 +-------- .../Microsoft.ML.GenAI.Phi.Tests.csproj | 8 +------- test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs | 17 +---------------- test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs | 11 +---------- .../SemanticKernelTests.cs | 9 +-------- 7 files changed, 11 insertions(+), 68 deletions(-) diff --git a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj index bc0eea0bb6..8611d2a701 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj @@ -3,12 +3,12 @@ net6.0 enable + $(NoWarn);MSML_ExtendBaseTestClass enable - @@ -17,12 +17,6 @@ - - - - - - diff --git a/test/Microsoft.ML.GenAI.Core.Tests/QuantizedLinearTests.cs b/test/Microsoft.ML.GenAI.Core.Tests/QuantizedLinearTests.cs index e9687454f4..d1653721ba 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/QuantizedLinearTests.cs +++ b/test/Microsoft.ML.GenAI.Core.Tests/QuantizedLinearTests.cs @@ -1,24 +1,17 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using FluentAssertions; using Microsoft.ML.GenAI.Core.Extension; -using Microsoft.ML.TestFramework; using TorchSharp; using Xunit; -using Xunit.Abstractions; using static TorchSharp.torch; namespace Microsoft.ML.GenAI.Core.Tests; -public class QuantizedLinearTests : BaseTestClass +public class QuantizedLinearTests { - public QuantizedLinearTests(ITestOutputHelper output) : base(output) - { - } - [Fact] public void Int4QuantizeSizeTests() { diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/AutoGenTests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/AutoGenTests.cs index e08e496eff..33ab565fe7 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/AutoGenTests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/AutoGenTests.cs @@ -5,21 +5,14 @@ using AutoGen.Core; using FluentAssertions; using Microsoft.ML.GenAI.Core; -using Microsoft.ML.TestFramework; using Microsoft.ML.Tokenizers; using Moq; using Xunit; -using Xunit.Abstractions; namespace Microsoft.ML.GenAI.Phi.Tests; -public class AutoGenTests : BaseTestClass +public class AutoGenTests { - public AutoGenTests(ITestOutputHelper helper) - : base(helper) - { - } - [Fact] public async Task ItGenerateTextReply() { diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj index bf4580ee9f..338833327f 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj @@ -3,13 +3,13 @@ net6.0 enable + $(NoWarn);MSML_ExtendBaseTestClass enable - @@ -21,12 +21,6 @@ - - - - - - diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs index 3d99cd986d..52863216cb 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs @@ -2,33 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; -using System.Text.Json; -using System.Threading.Tasks; using ApprovalTests; using ApprovalTests.Namers; using ApprovalTests.Reporters; using FluentAssertions; using Microsoft.ML.GenAI.Core.Extension; -using Microsoft.ML.GenAI.Phi.Module; -using Microsoft.ML.TestFramework; using Microsoft.ML.Tokenizers; -using TorchSharp; using Xunit; -using Xunit.Abstractions; -using static TorchSharp.torch; namespace Microsoft.ML.GenAI.Phi.Tests; -public class Phi2Test : BaseTestClass +public class Phi2Test { - public Phi2Test(ITestOutputHelper output) : base(output) - { - torch.set_default_device("meta"); - } - [Fact] [UseReporter(typeof(DiffReporter))] [UseApprovalSubdirectory("Approvals")] diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs index 543a415380..385d439337 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs @@ -9,21 +9,12 @@ using ApprovalTests.Reporters; using FluentAssertions; using Microsoft.ML.GenAI.Core.Extension; -using Microsoft.ML.TestFramework; -using TorchSharp; using Xunit; -using Xunit.Abstractions; -using static TorchSharp.torch; namespace Microsoft.ML.GenAI.Phi.Tests; -public class Phi3Tests : BaseTestClass +public class Phi3Tests { - public Phi3Tests(ITestOutputHelper output) : base(output) - { - torch.set_default_device("meta"); - } - [Fact] [UseReporter(typeof(DiffReporter))] [UseApprovalSubdirectory("Approvals")] diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/SemanticKernelTests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/SemanticKernelTests.cs index 09c11c8886..3bed3b5a79 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/SemanticKernelTests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/SemanticKernelTests.cs @@ -6,23 +6,16 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.ML.GenAI.Core; using Microsoft.ML.GenAI.Phi.Extension; -using Microsoft.ML.TestFramework; using Microsoft.ML.Tokenizers; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; using Moq; using Xunit; -using Xunit.Abstractions; namespace Microsoft.ML.GenAI.Phi.Tests; -public class SemanticKernelTests : BaseTestClass +public class SemanticKernelTests { - public SemanticKernelTests(ITestOutputHelper helper) - : base(helper) - { - } - [Fact] public async Task ItAddPhi3CausalLMChatCompletionServiceTestAsync() { From 745f40bb3882e8b2886c57d40d807b9ea368e1d6 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Mon, 15 Jul 2024 10:36:45 -0700 Subject: [PATCH 20/41] remove Phi3Tokenizer and use LlamaTokenizer instead --- .../Phi3Mini/AutoGenSample.cs | 2 +- .../Phi3Mini/SemanticKernelSample.cs | 2 +- .../Phi3Mini/Utils.cs | 8 +- .../Pipeline/CausalLMPipeline.cs | 10 +- .../Phi2/Phi2Tokenzier.cs | 30 +++ .../Phi3/Phi3Tokenzier.cs | 192 +++--------------- src/Microsoft.ML.Tokenizers/Tokenizer.cs | 15 -- .../Phi3Tests.TokenizerTest.approved.txt | 14 +- test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs | 2 +- .../Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs | 9 +- 10 files changed, 90 insertions(+), 194 deletions(-) create mode 100644 src/Microsoft.ML.GenAI.Phi/Phi2/Phi2Tokenzier.cs diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs index 392aec674d..379fd2b97b 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs @@ -26,7 +26,7 @@ public static async Task RunAsync() torch.manual_seed(1); torch.set_default_dtype(defaultType); var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"; - var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device, quantizeToInt8: false); + var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device); // agent var agent = new Phi3Agent(pipeline, "assistant") diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs index 8ce0ead983..d82bd20704 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs @@ -56,7 +56,7 @@ public static async Task RunTextGenerationSample() .AddPhi3AsTextGeneration(pipeline) .Build(); - var response = await kernel.InvokePromptAsync("write a C# program to calculate the factorial of a number"); + var response = await kernel.InvokePromptAsync("Tell a joke"); Console.WriteLine(response); } } diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs index f8147fa822..5e53ef0ac4 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs @@ -10,12 +10,13 @@ using TorchSharp; using Microsoft.ML.GenAI.Core.Extension; using System.Text.Json; +using Microsoft.ML.Tokenizers; namespace Microsoft.ML.GenAI.Samples.Phi3Mini; internal static class Utils { - public static CausalLMPipeline LoadPhi3Mini4KFromFolder( + public static ICausalLMPipeline LoadPhi3Mini4KFromFolder( string weightFolder, string configName = "config.json", string device = "cuda", @@ -31,7 +32,8 @@ public static CausalLMPipeline LoadPhi3Mini4KFro var config = JsonSerializer.Deserialize(System.IO.File.ReadAllText(configPath)) ?? throw new ArgumentNullException(nameof(configPath)); var timer = System.Diagnostics.Stopwatch.StartNew(); var model = new Phi3ForCasualLM(config); - var tokenizer = Phi3Tokenizer.FromPretrained(weightFolder); + var tokenzierPath = System.IO.Path.Combine(weightFolder, "tokenizer.model"); + var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenzierPath); if (quantizeToInt8) { @@ -93,7 +95,7 @@ public static CausalLMPipeline LoadPhi3Mini4KFro model = model.ToDynamicLoadingModel(deviceMap, "cuda"); timer.Stop(); Console.WriteLine($"Phi3 loaded to device: {device} in {timer.ElapsedMilliseconds / 1000} s"); - var pipeline = new CausalLMPipeline(tokenizer, model, device); + var pipeline = new CausalLMPipeline(tokenizer, model, device); torch.set_default_device(device); return pipeline; diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs index 57dc7f88c7..1a4d443ccb 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs @@ -244,7 +244,15 @@ public virtual IEnumerable GenerateStreaming( List stopTokenIds = [[]]; if (stopSequences != null) { - stopTokenIds.AddRange(stopSequences.Select(x => this.Tokenizer.EncodeToIds(x, false, false).ToArray())); + stopTokenIds.AddRange(stopSequences.Select(x => + { + var tokens = this.Tokenizer.EncodeToTokens(x, out var _, false, false); + + return tokens + .Where(t => t.Offset != (0, 0)) + .Select(t => t.Id) + .ToArray(); + })); } stopTokenIds = stopTokenIds.Where(ids => ids.Count() > 0).ToList(); diff --git a/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2Tokenzier.cs b/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2Tokenzier.cs new file mode 100644 index 0000000000..3444c74e31 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2Tokenzier.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Text; +using System.Text.RegularExpressions; +using Microsoft.ML.Tokenizers; +using Tensorboard; + +/// +/// The utility class to create tokenizer for phi-3 model. +/// +public class Phi2TokenizerHelper +{ + public static CodeGenTokenizer Create( + string folder, + string vocabFile = "vocab.json", + string mergesFile = "merges.txt", + bool addPrefixSpace = false, + bool addBeginOfSentence = false, + bool addEndOfSentence = false) + { + var vocabPath = Path.Combine(folder, vocabFile); + var mergesPath = Path.Combine(folder, mergesFile); + using var vocabStream = File.OpenRead(vocabPath); + using var mergesStream = File.OpenRead(mergesPath); + + return CodeGenTokenizer.Create(vocabStream, mergesStream, addPrefixSpace, addBeginOfSentence, addEndOfSentence); + } +} diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs index 11116cdb72..01a0af3061 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs @@ -6,10 +6,11 @@ using System.Text.RegularExpressions; using Microsoft.ML.Tokenizers; -public class Phi3Tokenizer : Tokenizer +/// +/// The utility class to create tokenizer for phi-3 model. +/// +public class Phi3TokenizerHelper { - private readonly SentencePieceBpe _tokenizer; - private readonly bool _addPrecedingSpace; private const string SystemSymbol = "<|system|>"; private const string UserSymbol = "<|user|>"; private const string AssistantSymbol = "<|assistant|>"; @@ -26,176 +27,31 @@ public class Phi3Tokenizer : Tokenizer { EndSymbol, EndSymbolId } }; - public Phi3Tokenizer(string modelPath, + public static LlamaTokenizer FromPretrained( + string modelPath, + string systemSymbol = SystemSymbol, + string userSymbol = UserSymbol, + string assistantSymbol = AssistantSymbol, + string endSymbol = EndSymbol, + int systemSymbolId = SystemSymbolId, + int userSymbolId = UserSymbolId, + int assistantSymbolId = AssistantSymbolId, + int endSymbolId = EndSymbolId, bool addPrecedingSpace = true) { var modelStream = File.OpenRead(modelPath); - this._addPrecedingSpace = addPrecedingSpace; - this._tokenizer = (SentencePieceBpe)Tokenizer.CreateLlama(modelStream, false, false); - } - - public static Phi3Tokenizer FromPretrained( - string folder, - string modelName = "tokenizer.model") - { - return new Phi3Tokenizer(Path.Combine(folder, modelName)); - } - - public int BosId { get => this._tokenizer.BeginningOfSentenceId; } - - public int EosId { get => this._tokenizer.EndOfSentenceId; } - - public override IReadOnlyList Encode(ReadOnlySpan text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) - { - var tokens = new List(); - var normalizedText = new StringBuilder(); - var input = text.ToString(); - - // step 1: - // replace all special tokens to - var re = new Regex($"{SystemSymbol.Replace("|", "\\|")}|{UserSymbol.Replace("|", "\\|")}|{AssistantSymbol.Replace("|", "\\|")}|{EndSymbol.Replace("|", "\\|")}"); - var matches = re.Matches(input); - var matchesList = new List(); - foreach (Match match in matches) - { - // replace the first special tokens with - var specialToken = match.Value; - var index = input.IndexOf(specialToken); - var subString = input.Substring(0, index); - var subTokens = this._tokenizer.Encode(subString, out var subNormalizeString, addBeginningOfSentence: false, addEndOfSentence: false, considerPreTokenization: considerPreTokenization, considerNormalization: considerNormalization).ToArray(); - normalizedText.Append(subNormalizeString); - tokens.AddRange(subTokens); - tokens.Add(new Token(this._specialTokenMap[specialToken], specialToken, (index, specialToken.Length))); - input = input.Remove(0, index + specialToken.Length); - } - - tokens.AddRange(this._tokenizer.Encode(input, out var normailzeString, addBeginningOfSentence: false, addEndOfSentence: false, considerPreTokenization: considerPreTokenization, considerNormalization: considerNormalization).ToArray()); - - normalizedText.Append(normailzeString); - normalizedString = normalizedText.ToString(); - - return tokens.ToArray(); - } - - public IReadOnlyList EncodeToIds( - string text, - bool addBeginningOfSentence, - bool addEndOfSentence, - bool considerPreTokenization = true, - bool considerNormalization = true) - { - var ids = this.EncodeToIds(text, considerPreTokenization: considerPreTokenization, considerNormalization: considerNormalization); - - if (addBeginningOfSentence) - { - ids = new int[] { this.BosId }.Concat(ids).ToArray(); - } - - if (addEndOfSentence) - { - ids = ids.Concat(new int[] { this.EosId }).ToArray(); - } - - return ids; - } - - public override IReadOnlyList EncodeToIds(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true) - { - var input = text.ToString(); - // step 1: - // replace all special tokens to - var re = new Regex($"{SystemSymbol.Replace("|", "\\|")}|{UserSymbol.Replace("|", "\\|")}|{AssistantSymbol.Replace("|", "\\|")}|{EndSymbol.Replace("|", "\\|")}"); - var matches = re.Matches(input); - var matchesList = new List(); - var tokens = new List(); - foreach (Match match in matches) - { - var specialToken = match.Value; - var index = input.IndexOf(specialToken); - var subString = input.Substring(0, index); - var subTokens = this._tokenizer.EncodeToIds(subString, addBeginningOfSentence: false, addEndOfSentence: false, considerPreTokenization: false, considerNormalization: true).ToArray(); - // remove the first sub Token as it will always be '_' - tokens.AddRange(subTokens); - tokens.Add(this._specialTokenMap[specialToken]); - input = input.Remove(0, index + specialToken.Length); - } - - tokens.AddRange(this._tokenizer.EncodeToIds(input, addBeginningOfSentence: false, addEndOfSentence: false, considerPreTokenization: false, considerNormalization: true).ToArray()); - - return tokens.ToArray(); - } - - public override IReadOnlyList EncodeToIds(ReadOnlySpan text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) - { - var tokens = this.Encode(text, out normalizedText, considerPreTokenization, considerNormalization); - var tokenIds = tokens.Select(x => x.Id).ToArray(); - - textLength = normalizedText?.Length ?? 0; - - return tokenIds.Length > maxTokenCount ? tokenIds.Take(maxTokenCount).ToArray() : tokenIds; - } - - public override int CountTokens(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true) - { - var tokens = this.EncodeToIds(text, considerPreTokenization, considerNormalization); - - return tokens.Count; - } - - public override int IndexOfTokenCount(ReadOnlySpan text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) - { - return _tokenizer.IndexOfTokenCount(text, maxTokenCount, out normalizedString, out tokenCount, considerPreTokenization, considerNormalization); - } - - public override int LastIndexOfTokenCount(ReadOnlySpan text, int maxTokenCount, out string? processedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) - { - return _tokenizer.LastIndexOfTokenCount(text, maxTokenCount, out processedText, out tokenCount, considerPreTokenization, considerNormalization); - } - - public override string? Decode(IEnumerable ids) - { - // step 1 - // replace all special token ids to ukn ids - var replacedIds = ids.SelectMany(id => - { - if (this._specialTokenMap.ContainsValue(id)) - { - var key = this._specialTokenMap.First(x => x.Value == id).Key; - var ids = this._tokenizer.EncodeToIds(key, false, false, false, false); - var recoverKey = this._tokenizer.Decode(ids) ?? throw new Exception("Failed to decode ids"); - return ids; - } - else + var llamaTokenizer = LlamaTokenizer.Create( + modelStream, + addPrecedingSpace, + specialTokens: new Dictionary { - return new List { id }; - } - }); - - var str = this._tokenizer.Decode(replacedIds) ?? throw new Exception("Failed to decode ids"); - - return str; - } - - public override int? MapTokenToId(ReadOnlySpan token) - { - // check if token in special tokens - var tokenStr = token.ToString(); - if (_specialTokenMap.ContainsKey(tokenStr)) - { - return _specialTokenMap[tokenStr]; - } - - return _tokenizer.MapTokenToId(token); - } - - public override string? MapIdToToken(int id) - { - if (_specialTokenMap.ContainsValue(id)) - { - return _specialTokenMap.First(x => x.Value == id).Key; - } + { systemSymbol, systemSymbolId }, + { userSymbol, userSymbolId }, + { assistantSymbol, assistantSymbolId }, + { endSymbol, endSymbolId } + }); - return _tokenizer.MapIdToToken(id); + return llamaTokenizer; } } diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs index bd8da3e3be..3c930e1f10 100644 --- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs @@ -305,21 +305,6 @@ public int GetIndexByTokenCountFromEnd(ReadOnlySpan text, int maxTokenCoun /// The operation status indicates whether all IDs were successfully decoded or if the is too small to contain the entire decoded result. public abstract OperationStatus Decode(IEnumerable ids, Span destination, out int idsConsumed, out int charsWritten); - public static CodeGen CreatePhi2( - string folder, - string vocabFile = "vocab.json", - string mergesFile = "merges.txt", - bool addPrefixSpace = false, - bool addBeginOfSentence = false, - bool addEndOfSentence = false) - { - var vocabPath = Path.Combine(folder, vocabFile); - var mergesPath = Path.Combine(folder, mergesFile); - using var vocabStream = File.OpenRead(vocabPath); - using var mergesStream = File.OpenRead(mergesPath); - - return (CodeGen)CreateCodeGen(vocabStream, mergesStream, addPrefixSpace, addBeginOfSentence, addEndOfSentence); - } internal static IEnumerable<(int Offset, int Length)>? InitializeForEncoding( string? text, diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.TokenizerTest.approved.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.TokenizerTest.approved.txt index 95b1fb630d..70624d24df 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.TokenizerTest.approved.txt +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi3Tests.TokenizerTest.approved.txt @@ -1,8 +1,20 @@ +Can you provide ways to eat combinations of bananas and dragonfruits? 1, 1815, 366, 3867, 5837, 304, 17545, 18240, 310, 9892, 16397, 322, 8338, 265, 29888, 21211, 29973 +Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey. 1, 18585, 29991, 2266, 526, 777, 5837, 304, 17545, 9892, 16397, 322, 8338, 265, 29888, 21211, 4208, 29901, 29871, 29896, 29889, 10765, 1648, 322, 8338, 265, 29888, 9216, 10597, 347, 29901, 3164, 355, 9892, 16397, 322, 8338, 265, 29888, 21211, 4208, 411, 777, 27274, 322, 298, 4992, 29889, 29871, 29906, 29889, 10765, 1648, 322, 8338, 265, 29888, 9216, 4497, 328, 29901, 23478, 269, 506, 287, 9892, 16397, 322, 8338, 265, 29888, 21211, 4208, 411, 777, 454, 3712, 3623, 625, 322, 298, 4992, 29889 +What about solving an 2x + 3 = 7 equation? 1, 1724, 1048, 17069, 385, 29871, 29906, 29916, 718, 29871, 29941, 353, 29871, 29955, 6306, 29973 + +Count to 3 + 1, 29871, 13, 3981, 304, 29871, 29941, 13 +<|user|> 1, 32010 +<|end|> 1, 32007 +<|assistant|> 1, 32001 -1, 32010, 29871, 13, 3981, 304, 29871, 29941, 32007, 29871, 13, 32001 +<|user|> +Count to 3<|end|> +<|assistant|> +1, 32010, 29871, 13, 3981, 304, 29871, 29941, 32007, 13, 32001 diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs index 52863216cb..16b02d8eff 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs @@ -30,7 +30,7 @@ public void LoadSafeTensorShapeTest() public void TokenizerTest() { var modelWeightFolder = Path.Join("Phi-2"); - var tokenizer = Tokenizer.CreatePhi2(modelWeightFolder, addBeginOfSentence: true); + var tokenizer = Phi2TokenizerHelper.Create(modelWeightFolder, addBeginOfSentence: true); tokenizer.EndOfSentenceId.Should().Be(50256); tokenizer.BeginningOfSentenceId.Should().Be(50256); var messages = new string[] diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs index 385d439337..b65d0bf995 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs @@ -104,9 +104,10 @@ public void Phi3Mini128KLayerSizeTest() public void TokenizerTest() { var modelWeightFolder = Path.Join("Llama"); - var tokenizer = Phi3Tokenizer.FromPretrained(modelWeightFolder); - tokenizer.BosId.Should().Be(1); - tokenizer.EosId.Should().Be(2); + var modelPath = Path.Join(modelWeightFolder, "tokenizer.model"); + var tokenizer = Phi3TokenizerHelper.FromPretrained(modelPath); + tokenizer.BeginningOfSentenceId.Should().Be(1); + tokenizer.EndOfSentenceId.Should().Be(2); // test <|end|> var endIds = tokenizer.EncodeToIds("<|end|>", addBeginningOfSentence: false, addEndOfSentence: false, considerPreTokenization: false, considerNormalization: false); @@ -127,6 +128,8 @@ public void TokenizerTest() foreach (var message in messages) { var tokenizeIds = tokenizer.EncodeToIds(message, true, false, considerPreTokenization: true); + var decodeToString = tokenizer.Decode(tokenizeIds, considerSpecialTokens: true); + sb.AppendLine(decodeToString); var tokenizedStr = string.Join(", ", tokenizeIds.Select(x => x.ToString())); sb.AppendLine(tokenizedStr); From 57444cc7089c07e9396c72f2a035a3dfb2598fb0 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Mon, 15 Jul 2024 10:38:03 -0700 Subject: [PATCH 21/41] revert change in tokenizer package --- src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj | 4 ---- src/Microsoft.ML.Tokenizers/Tokenizer.cs | 1 - 2 files changed, 5 deletions(-) diff --git a/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj b/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj index a61041f8e1..fbff32071e 100644 --- a/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj +++ b/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj @@ -97,10 +97,6 @@ - - - - text, int maxTokenCoun /// The operation status indicates whether all IDs were successfully decoded or if the is too small to contain the entire decoded result. public abstract OperationStatus Decode(IEnumerable ids, Span destination, out int idsConsumed, out int charsWritten); - internal static IEnumerable<(int Offset, int Length)>? InitializeForEncoding( string? text, ReadOnlySpan textSpan, From 4745683201099d70bbf032712b1ff3e43f815c52 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 17 Jul 2024 13:39:18 -0700 Subject: [PATCH 22/41] run test on x64 --- .../Microsoft.ML.GenAI.Core.Tests.csproj | 4 ++++ .../Microsoft.ML.GenAI.Phi.Tests.csproj | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj index 8611d2a701..aa1c14c620 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj @@ -17,6 +17,10 @@ + + + + diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj index 338833327f..87d50c122c 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj @@ -21,6 +21,11 @@ + + + + + From b933ce4d380f3ab42b9f1003313b6c8b6db8f85a Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 17 Jul 2024 13:58:20 -0700 Subject: [PATCH 23/41] fix tests --- .../Microsoft.ML.GenAI.Phi.Tests.csproj | 10 +++++----- .../{Phi2Test.cs => Phi2Tests.cs} | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) rename test/Microsoft.ML.GenAI.Phi.Tests/{Phi2Test.cs => Phi2Tests.cs} (98%) diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj index 87d50c122c..c520ba7db3 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj @@ -22,14 +22,14 @@ - - + + - - - + + + diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs similarity index 98% rename from test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs rename to test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs index 16b02d8eff..2a938de68d 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Test.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs @@ -12,7 +12,7 @@ using Xunit; namespace Microsoft.ML.GenAI.Phi.Tests; -public class Phi2Test +public class Phi2Tests { [Fact] [UseReporter(typeof(DiffReporter))] From 43dd37f99375c855814717cb891317d1d8eacfda Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 17 Jul 2024 15:00:05 -0700 Subject: [PATCH 24/41] check in approved file --- .../Microsoft.ML.GenAI.Core.Tests.csproj | 2 +- ...Tests.LoadSafeTensorShapeTest.approved.txt | 453 ++++++++++++++++++ .../Phi2Tests.TokenizerTest.approved.txt | 3 + .../Microsoft.ML.GenAI.Phi.Tests.csproj | 2 +- .../Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs | 1 + .../Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs | 1 + 6 files changed, 460 insertions(+), 2 deletions(-) create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi2Tests.LoadSafeTensorShapeTest.approved.txt create mode 100644 test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi2Tests.TokenizerTest.approved.txt diff --git a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj index aa1c14c620..612fb7df5b 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj @@ -1,7 +1,7 @@  - net6.0 + net6.0;net8.0 enable $(NoWarn);MSML_ExtendBaseTestClass enable diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi2Tests.LoadSafeTensorShapeTest.approved.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi2Tests.LoadSafeTensorShapeTest.approved.txt new file mode 100644 index 0000000000..75e17ad1a6 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi2Tests.LoadSafeTensorShapeTest.approved.txt @@ -0,0 +1,453 @@ +0: lm_head.bias shape: [51200] +1: lm_head.weight shape: [51200, 2560] +2: model.embed_tokens.weight shape: [51200, 2560] +3: model.final_layernorm.bias shape: [2560] +4: model.final_layernorm.weight shape: [2560] +5: model.layers.0.input_layernorm.bias shape: [2560] +6: model.layers.0.input_layernorm.weight shape: [2560] +7: model.layers.0.mlp.fc1.bias shape: [10240] +8: model.layers.0.mlp.fc1.weight shape: [10240, 2560] +9: model.layers.0.mlp.fc2.bias shape: [2560] +10: model.layers.0.mlp.fc2.weight shape: [2560, 10240] +11: model.layers.0.self_attn.dense.bias shape: [2560] +12: model.layers.0.self_attn.dense.weight shape: [2560, 2560] +13: model.layers.0.self_attn.k_proj.bias shape: [2560] +14: model.layers.0.self_attn.k_proj.weight shape: [2560, 2560] +15: model.layers.0.self_attn.q_proj.bias shape: [2560] +16: model.layers.0.self_attn.q_proj.weight shape: [2560, 2560] +17: model.layers.0.self_attn.v_proj.bias shape: [2560] +18: model.layers.0.self_attn.v_proj.weight shape: [2560, 2560] +19: model.layers.1.input_layernorm.bias shape: [2560] +20: model.layers.1.input_layernorm.weight shape: [2560] +21: model.layers.1.mlp.fc1.bias shape: [10240] +22: model.layers.1.mlp.fc1.weight shape: [10240, 2560] +23: model.layers.1.mlp.fc2.bias shape: [2560] +24: model.layers.1.mlp.fc2.weight shape: [2560, 10240] +25: model.layers.1.self_attn.dense.bias shape: [2560] +26: model.layers.1.self_attn.dense.weight shape: [2560, 2560] +27: model.layers.1.self_attn.k_proj.bias shape: [2560] +28: model.layers.1.self_attn.k_proj.weight shape: [2560, 2560] +29: model.layers.1.self_attn.q_proj.bias shape: [2560] +30: model.layers.1.self_attn.q_proj.weight shape: [2560, 2560] +31: model.layers.1.self_attn.v_proj.bias shape: [2560] +32: model.layers.1.self_attn.v_proj.weight shape: [2560, 2560] +33: model.layers.10.input_layernorm.bias shape: [2560] +34: model.layers.10.input_layernorm.weight shape: [2560] +35: model.layers.10.mlp.fc1.bias shape: [10240] +36: model.layers.10.mlp.fc1.weight shape: [10240, 2560] +37: model.layers.10.mlp.fc2.bias shape: [2560] +38: model.layers.10.mlp.fc2.weight shape: [2560, 10240] +39: model.layers.10.self_attn.dense.bias shape: [2560] +40: model.layers.10.self_attn.dense.weight shape: [2560, 2560] +41: model.layers.10.self_attn.k_proj.bias shape: [2560] +42: model.layers.10.self_attn.k_proj.weight shape: [2560, 2560] +43: model.layers.10.self_attn.q_proj.bias shape: [2560] +44: model.layers.10.self_attn.q_proj.weight shape: [2560, 2560] +45: model.layers.10.self_attn.v_proj.bias shape: [2560] +46: model.layers.10.self_attn.v_proj.weight shape: [2560, 2560] +47: model.layers.11.input_layernorm.bias shape: [2560] +48: model.layers.11.input_layernorm.weight shape: [2560] +49: model.layers.11.mlp.fc1.bias shape: [10240] +50: model.layers.11.mlp.fc1.weight shape: [10240, 2560] +51: model.layers.11.mlp.fc2.bias shape: [2560] +52: model.layers.11.mlp.fc2.weight shape: [2560, 10240] +53: model.layers.11.self_attn.dense.bias shape: [2560] +54: model.layers.11.self_attn.dense.weight shape: [2560, 2560] +55: model.layers.11.self_attn.k_proj.bias shape: [2560] +56: model.layers.11.self_attn.k_proj.weight shape: [2560, 2560] +57: model.layers.11.self_attn.q_proj.bias shape: [2560] +58: model.layers.11.self_attn.q_proj.weight shape: [2560, 2560] +59: model.layers.11.self_attn.v_proj.bias shape: [2560] +60: model.layers.11.self_attn.v_proj.weight shape: [2560, 2560] +61: model.layers.12.input_layernorm.bias shape: [2560] +62: model.layers.12.input_layernorm.weight shape: [2560] +63: model.layers.12.mlp.fc1.bias shape: [10240] +64: model.layers.12.mlp.fc1.weight shape: [10240, 2560] +65: model.layers.12.mlp.fc2.bias shape: [2560] +66: model.layers.12.mlp.fc2.weight shape: [2560, 10240] +67: model.layers.12.self_attn.dense.bias shape: [2560] +68: model.layers.12.self_attn.dense.weight shape: [2560, 2560] +69: model.layers.12.self_attn.k_proj.bias shape: [2560] +70: model.layers.12.self_attn.k_proj.weight shape: [2560, 2560] +71: model.layers.12.self_attn.q_proj.bias shape: [2560] +72: model.layers.12.self_attn.q_proj.weight shape: [2560, 2560] +73: model.layers.12.self_attn.v_proj.bias shape: [2560] +74: model.layers.12.self_attn.v_proj.weight shape: [2560, 2560] +75: model.layers.13.input_layernorm.bias shape: [2560] +76: model.layers.13.input_layernorm.weight shape: [2560] +77: model.layers.13.mlp.fc1.bias shape: [10240] +78: model.layers.13.mlp.fc1.weight shape: [10240, 2560] +79: model.layers.13.mlp.fc2.bias shape: [2560] +80: model.layers.13.mlp.fc2.weight shape: [2560, 10240] +81: model.layers.13.self_attn.dense.bias shape: [2560] +82: model.layers.13.self_attn.dense.weight shape: [2560, 2560] +83: model.layers.13.self_attn.k_proj.bias shape: [2560] +84: model.layers.13.self_attn.k_proj.weight shape: [2560, 2560] +85: model.layers.13.self_attn.q_proj.bias shape: [2560] +86: model.layers.13.self_attn.q_proj.weight shape: [2560, 2560] +87: model.layers.13.self_attn.v_proj.bias shape: [2560] +88: model.layers.13.self_attn.v_proj.weight shape: [2560, 2560] +89: model.layers.14.input_layernorm.bias shape: [2560] +90: model.layers.14.input_layernorm.weight shape: [2560] +91: model.layers.14.mlp.fc1.bias shape: [10240] +92: model.layers.14.mlp.fc1.weight shape: [10240, 2560] +93: model.layers.14.mlp.fc2.bias shape: [2560] +94: model.layers.14.mlp.fc2.weight shape: [2560, 10240] +95: model.layers.14.self_attn.dense.bias shape: [2560] +96: model.layers.14.self_attn.dense.weight shape: [2560, 2560] +97: model.layers.14.self_attn.k_proj.bias shape: [2560] +98: model.layers.14.self_attn.k_proj.weight shape: [2560, 2560] +99: model.layers.14.self_attn.q_proj.bias shape: [2560] +100: model.layers.14.self_attn.q_proj.weight shape: [2560, 2560] +101: model.layers.14.self_attn.v_proj.bias shape: [2560] +102: model.layers.14.self_attn.v_proj.weight shape: [2560, 2560] +103: model.layers.15.input_layernorm.bias shape: [2560] +104: model.layers.15.input_layernorm.weight shape: [2560] +105: model.layers.15.mlp.fc1.bias shape: [10240] +106: model.layers.15.mlp.fc1.weight shape: [10240, 2560] +107: model.layers.15.mlp.fc2.bias shape: [2560] +108: model.layers.15.mlp.fc2.weight shape: [2560, 10240] +109: model.layers.15.self_attn.dense.bias shape: [2560] +110: model.layers.15.self_attn.dense.weight shape: [2560, 2560] +111: model.layers.15.self_attn.k_proj.bias shape: [2560] +112: model.layers.15.self_attn.k_proj.weight shape: [2560, 2560] +113: model.layers.15.self_attn.q_proj.bias shape: [2560] +114: model.layers.15.self_attn.q_proj.weight shape: [2560, 2560] +115: model.layers.15.self_attn.v_proj.bias shape: [2560] +116: model.layers.15.self_attn.v_proj.weight shape: [2560, 2560] +117: model.layers.16.input_layernorm.bias shape: [2560] +118: model.layers.16.input_layernorm.weight shape: [2560] +119: model.layers.16.mlp.fc1.bias shape: [10240] +120: model.layers.16.mlp.fc1.weight shape: [10240, 2560] +121: model.layers.16.mlp.fc2.bias shape: [2560] +122: model.layers.16.mlp.fc2.weight shape: [2560, 10240] +123: model.layers.16.self_attn.dense.bias shape: [2560] +124: model.layers.16.self_attn.dense.weight shape: [2560, 2560] +125: model.layers.16.self_attn.k_proj.bias shape: [2560] +126: model.layers.16.self_attn.k_proj.weight shape: [2560, 2560] +127: model.layers.16.self_attn.q_proj.bias shape: [2560] +128: model.layers.16.self_attn.q_proj.weight shape: [2560, 2560] +129: model.layers.16.self_attn.v_proj.bias shape: [2560] +130: model.layers.16.self_attn.v_proj.weight shape: [2560, 2560] +131: model.layers.17.input_layernorm.bias shape: [2560] +132: model.layers.17.input_layernorm.weight shape: [2560] +133: model.layers.17.mlp.fc1.bias shape: [10240] +134: model.layers.17.mlp.fc1.weight shape: [10240, 2560] +135: model.layers.17.mlp.fc2.bias shape: [2560] +136: model.layers.17.mlp.fc2.weight shape: [2560, 10240] +137: model.layers.17.self_attn.dense.bias shape: [2560] +138: model.layers.17.self_attn.dense.weight shape: [2560, 2560] +139: model.layers.17.self_attn.k_proj.bias shape: [2560] +140: model.layers.17.self_attn.k_proj.weight shape: [2560, 2560] +141: model.layers.17.self_attn.q_proj.bias shape: [2560] +142: model.layers.17.self_attn.q_proj.weight shape: [2560, 2560] +143: model.layers.17.self_attn.v_proj.bias shape: [2560] +144: model.layers.17.self_attn.v_proj.weight shape: [2560, 2560] +145: model.layers.18.input_layernorm.bias shape: [2560] +146: model.layers.18.input_layernorm.weight shape: [2560] +147: model.layers.18.mlp.fc1.bias shape: [10240] +148: model.layers.18.mlp.fc1.weight shape: [10240, 2560] +149: model.layers.18.mlp.fc2.bias shape: [2560] +150: model.layers.18.mlp.fc2.weight shape: [2560, 10240] +151: model.layers.18.self_attn.dense.bias shape: [2560] +152: model.layers.18.self_attn.dense.weight shape: [2560, 2560] +153: model.layers.18.self_attn.k_proj.bias shape: [2560] +154: model.layers.18.self_attn.k_proj.weight shape: [2560, 2560] +155: model.layers.18.self_attn.q_proj.bias shape: [2560] +156: model.layers.18.self_attn.q_proj.weight shape: [2560, 2560] +157: model.layers.18.self_attn.v_proj.bias shape: [2560] +158: model.layers.18.self_attn.v_proj.weight shape: [2560, 2560] +159: model.layers.19.input_layernorm.bias shape: [2560] +160: model.layers.19.input_layernorm.weight shape: [2560] +161: model.layers.19.mlp.fc1.bias shape: [10240] +162: model.layers.19.mlp.fc1.weight shape: [10240, 2560] +163: model.layers.19.mlp.fc2.bias shape: [2560] +164: model.layers.19.mlp.fc2.weight shape: [2560, 10240] +165: model.layers.19.self_attn.dense.bias shape: [2560] +166: model.layers.19.self_attn.dense.weight shape: [2560, 2560] +167: model.layers.19.self_attn.k_proj.bias shape: [2560] +168: model.layers.19.self_attn.k_proj.weight shape: [2560, 2560] +169: model.layers.19.self_attn.q_proj.bias shape: [2560] +170: model.layers.19.self_attn.q_proj.weight shape: [2560, 2560] +171: model.layers.19.self_attn.v_proj.bias shape: [2560] +172: model.layers.19.self_attn.v_proj.weight shape: [2560, 2560] +173: model.layers.2.input_layernorm.bias shape: [2560] +174: model.layers.2.input_layernorm.weight shape: [2560] +175: model.layers.2.mlp.fc1.bias shape: [10240] +176: model.layers.2.mlp.fc1.weight shape: [10240, 2560] +177: model.layers.2.mlp.fc2.bias shape: [2560] +178: model.layers.2.mlp.fc2.weight shape: [2560, 10240] +179: model.layers.2.self_attn.dense.bias shape: [2560] +180: model.layers.2.self_attn.dense.weight shape: [2560, 2560] +181: model.layers.2.self_attn.k_proj.bias shape: [2560] +182: model.layers.2.self_attn.k_proj.weight shape: [2560, 2560] +183: model.layers.2.self_attn.q_proj.bias shape: [2560] +184: model.layers.2.self_attn.q_proj.weight shape: [2560, 2560] +185: model.layers.2.self_attn.v_proj.bias shape: [2560] +186: model.layers.2.self_attn.v_proj.weight shape: [2560, 2560] +187: model.layers.20.input_layernorm.bias shape: [2560] +188: model.layers.20.input_layernorm.weight shape: [2560] +189: model.layers.20.mlp.fc1.bias shape: [10240] +190: model.layers.20.mlp.fc1.weight shape: [10240, 2560] +191: model.layers.20.mlp.fc2.bias shape: [2560] +192: model.layers.20.mlp.fc2.weight shape: [2560, 10240] +193: model.layers.20.self_attn.dense.bias shape: [2560] +194: model.layers.20.self_attn.dense.weight shape: [2560, 2560] +195: model.layers.20.self_attn.k_proj.bias shape: [2560] +196: model.layers.20.self_attn.k_proj.weight shape: [2560, 2560] +197: model.layers.20.self_attn.q_proj.bias shape: [2560] +198: model.layers.20.self_attn.q_proj.weight shape: [2560, 2560] +199: model.layers.20.self_attn.v_proj.bias shape: [2560] +200: model.layers.20.self_attn.v_proj.weight shape: [2560, 2560] +201: model.layers.21.input_layernorm.bias shape: [2560] +202: model.layers.21.input_layernorm.weight shape: [2560] +203: model.layers.21.mlp.fc1.bias shape: [10240] +204: model.layers.21.mlp.fc1.weight shape: [10240, 2560] +205: model.layers.21.mlp.fc2.bias shape: [2560] +206: model.layers.21.mlp.fc2.weight shape: [2560, 10240] +207: model.layers.21.self_attn.dense.bias shape: [2560] +208: model.layers.21.self_attn.dense.weight shape: [2560, 2560] +209: model.layers.21.self_attn.k_proj.bias shape: [2560] +210: model.layers.21.self_attn.k_proj.weight shape: [2560, 2560] +211: model.layers.21.self_attn.q_proj.bias shape: [2560] +212: model.layers.21.self_attn.q_proj.weight shape: [2560, 2560] +213: model.layers.21.self_attn.v_proj.bias shape: [2560] +214: model.layers.21.self_attn.v_proj.weight shape: [2560, 2560] +215: model.layers.22.input_layernorm.bias shape: [2560] +216: model.layers.22.input_layernorm.weight shape: [2560] +217: model.layers.22.mlp.fc1.bias shape: [10240] +218: model.layers.22.mlp.fc1.weight shape: [10240, 2560] +219: model.layers.22.mlp.fc2.bias shape: [2560] +220: model.layers.22.mlp.fc2.weight shape: [2560, 10240] +221: model.layers.22.self_attn.dense.bias shape: [2560] +222: model.layers.22.self_attn.dense.weight shape: [2560, 2560] +223: model.layers.22.self_attn.k_proj.bias shape: [2560] +224: model.layers.22.self_attn.k_proj.weight shape: [2560, 2560] +225: model.layers.22.self_attn.q_proj.bias shape: [2560] +226: model.layers.22.self_attn.q_proj.weight shape: [2560, 2560] +227: model.layers.22.self_attn.v_proj.bias shape: [2560] +228: model.layers.22.self_attn.v_proj.weight shape: [2560, 2560] +229: model.layers.23.input_layernorm.bias shape: [2560] +230: model.layers.23.input_layernorm.weight shape: [2560] +231: model.layers.23.mlp.fc1.bias shape: [10240] +232: model.layers.23.mlp.fc1.weight shape: [10240, 2560] +233: model.layers.23.mlp.fc2.bias shape: [2560] +234: model.layers.23.mlp.fc2.weight shape: [2560, 10240] +235: model.layers.23.self_attn.dense.bias shape: [2560] +236: model.layers.23.self_attn.dense.weight shape: [2560, 2560] +237: model.layers.23.self_attn.k_proj.bias shape: [2560] +238: model.layers.23.self_attn.k_proj.weight shape: [2560, 2560] +239: model.layers.23.self_attn.q_proj.bias shape: [2560] +240: model.layers.23.self_attn.q_proj.weight shape: [2560, 2560] +241: model.layers.23.self_attn.v_proj.bias shape: [2560] +242: model.layers.23.self_attn.v_proj.weight shape: [2560, 2560] +243: model.layers.24.input_layernorm.bias shape: [2560] +244: model.layers.24.input_layernorm.weight shape: [2560] +245: model.layers.24.mlp.fc1.bias shape: [10240] +246: model.layers.24.mlp.fc1.weight shape: [10240, 2560] +247: model.layers.24.mlp.fc2.bias shape: [2560] +248: model.layers.24.mlp.fc2.weight shape: [2560, 10240] +249: model.layers.24.self_attn.dense.bias shape: [2560] +250: model.layers.24.self_attn.dense.weight shape: [2560, 2560] +251: model.layers.24.self_attn.k_proj.bias shape: [2560] +252: model.layers.24.self_attn.k_proj.weight shape: [2560, 2560] +253: model.layers.24.self_attn.q_proj.bias shape: [2560] +254: model.layers.24.self_attn.q_proj.weight shape: [2560, 2560] +255: model.layers.24.self_attn.v_proj.bias shape: [2560] +256: model.layers.24.self_attn.v_proj.weight shape: [2560, 2560] +257: model.layers.25.input_layernorm.bias shape: [2560] +258: model.layers.25.input_layernorm.weight shape: [2560] +259: model.layers.25.mlp.fc1.bias shape: [10240] +260: model.layers.25.mlp.fc1.weight shape: [10240, 2560] +261: model.layers.25.mlp.fc2.bias shape: [2560] +262: model.layers.25.mlp.fc2.weight shape: [2560, 10240] +263: model.layers.25.self_attn.dense.bias shape: [2560] +264: model.layers.25.self_attn.dense.weight shape: [2560, 2560] +265: model.layers.25.self_attn.k_proj.bias shape: [2560] +266: model.layers.25.self_attn.k_proj.weight shape: [2560, 2560] +267: model.layers.25.self_attn.q_proj.bias shape: [2560] +268: model.layers.25.self_attn.q_proj.weight shape: [2560, 2560] +269: model.layers.25.self_attn.v_proj.bias shape: [2560] +270: model.layers.25.self_attn.v_proj.weight shape: [2560, 2560] +271: model.layers.26.input_layernorm.bias shape: [2560] +272: model.layers.26.input_layernorm.weight shape: [2560] +273: model.layers.26.mlp.fc1.bias shape: [10240] +274: model.layers.26.mlp.fc1.weight shape: [10240, 2560] +275: model.layers.26.mlp.fc2.bias shape: [2560] +276: model.layers.26.mlp.fc2.weight shape: [2560, 10240] +277: model.layers.26.self_attn.dense.bias shape: [2560] +278: model.layers.26.self_attn.dense.weight shape: [2560, 2560] +279: model.layers.26.self_attn.k_proj.bias shape: [2560] +280: model.layers.26.self_attn.k_proj.weight shape: [2560, 2560] +281: model.layers.26.self_attn.q_proj.bias shape: [2560] +282: model.layers.26.self_attn.q_proj.weight shape: [2560, 2560] +283: model.layers.26.self_attn.v_proj.bias shape: [2560] +284: model.layers.26.self_attn.v_proj.weight shape: [2560, 2560] +285: model.layers.27.input_layernorm.bias shape: [2560] +286: model.layers.27.input_layernorm.weight shape: [2560] +287: model.layers.27.mlp.fc1.bias shape: [10240] +288: model.layers.27.mlp.fc1.weight shape: [10240, 2560] +289: model.layers.27.mlp.fc2.bias shape: [2560] +290: model.layers.27.mlp.fc2.weight shape: [2560, 10240] +291: model.layers.27.self_attn.dense.bias shape: [2560] +292: model.layers.27.self_attn.dense.weight shape: [2560, 2560] +293: model.layers.27.self_attn.k_proj.bias shape: [2560] +294: model.layers.27.self_attn.k_proj.weight shape: [2560, 2560] +295: model.layers.27.self_attn.q_proj.bias shape: [2560] +296: model.layers.27.self_attn.q_proj.weight shape: [2560, 2560] +297: model.layers.27.self_attn.v_proj.bias shape: [2560] +298: model.layers.27.self_attn.v_proj.weight shape: [2560, 2560] +299: model.layers.28.input_layernorm.bias shape: [2560] +300: model.layers.28.input_layernorm.weight shape: [2560] +301: model.layers.28.mlp.fc1.bias shape: [10240] +302: model.layers.28.mlp.fc1.weight shape: [10240, 2560] +303: model.layers.28.mlp.fc2.bias shape: [2560] +304: model.layers.28.mlp.fc2.weight shape: [2560, 10240] +305: model.layers.28.self_attn.dense.bias shape: [2560] +306: model.layers.28.self_attn.dense.weight shape: [2560, 2560] +307: model.layers.28.self_attn.k_proj.bias shape: [2560] +308: model.layers.28.self_attn.k_proj.weight shape: [2560, 2560] +309: model.layers.28.self_attn.q_proj.bias shape: [2560] +310: model.layers.28.self_attn.q_proj.weight shape: [2560, 2560] +311: model.layers.28.self_attn.v_proj.bias shape: [2560] +312: model.layers.28.self_attn.v_proj.weight shape: [2560, 2560] +313: model.layers.29.input_layernorm.bias shape: [2560] +314: model.layers.29.input_layernorm.weight shape: [2560] +315: model.layers.29.mlp.fc1.bias shape: [10240] +316: model.layers.29.mlp.fc1.weight shape: [10240, 2560] +317: model.layers.29.mlp.fc2.bias shape: [2560] +318: model.layers.29.mlp.fc2.weight shape: [2560, 10240] +319: model.layers.29.self_attn.dense.bias shape: [2560] +320: model.layers.29.self_attn.dense.weight shape: [2560, 2560] +321: model.layers.29.self_attn.k_proj.bias shape: [2560] +322: model.layers.29.self_attn.k_proj.weight shape: [2560, 2560] +323: model.layers.29.self_attn.q_proj.bias shape: [2560] +324: model.layers.29.self_attn.q_proj.weight shape: [2560, 2560] +325: model.layers.29.self_attn.v_proj.bias shape: [2560] +326: model.layers.29.self_attn.v_proj.weight shape: [2560, 2560] +327: model.layers.3.input_layernorm.bias shape: [2560] +328: model.layers.3.input_layernorm.weight shape: [2560] +329: model.layers.3.mlp.fc1.bias shape: [10240] +330: model.layers.3.mlp.fc1.weight shape: [10240, 2560] +331: model.layers.3.mlp.fc2.bias shape: [2560] +332: model.layers.3.mlp.fc2.weight shape: [2560, 10240] +333: model.layers.3.self_attn.dense.bias shape: [2560] +334: model.layers.3.self_attn.dense.weight shape: [2560, 2560] +335: model.layers.3.self_attn.k_proj.bias shape: [2560] +336: model.layers.3.self_attn.k_proj.weight shape: [2560, 2560] +337: model.layers.3.self_attn.q_proj.bias shape: [2560] +338: model.layers.3.self_attn.q_proj.weight shape: [2560, 2560] +339: model.layers.3.self_attn.v_proj.bias shape: [2560] +340: model.layers.3.self_attn.v_proj.weight shape: [2560, 2560] +341: model.layers.30.input_layernorm.bias shape: [2560] +342: model.layers.30.input_layernorm.weight shape: [2560] +343: model.layers.30.mlp.fc1.bias shape: [10240] +344: model.layers.30.mlp.fc1.weight shape: [10240, 2560] +345: model.layers.30.mlp.fc2.bias shape: [2560] +346: model.layers.30.mlp.fc2.weight shape: [2560, 10240] +347: model.layers.30.self_attn.dense.bias shape: [2560] +348: model.layers.30.self_attn.dense.weight shape: [2560, 2560] +349: model.layers.30.self_attn.k_proj.bias shape: [2560] +350: model.layers.30.self_attn.k_proj.weight shape: [2560, 2560] +351: model.layers.30.self_attn.q_proj.bias shape: [2560] +352: model.layers.30.self_attn.q_proj.weight shape: [2560, 2560] +353: model.layers.30.self_attn.v_proj.bias shape: [2560] +354: model.layers.30.self_attn.v_proj.weight shape: [2560, 2560] +355: model.layers.31.input_layernorm.bias shape: [2560] +356: model.layers.31.input_layernorm.weight shape: [2560] +357: model.layers.31.mlp.fc1.bias shape: [10240] +358: model.layers.31.mlp.fc1.weight shape: [10240, 2560] +359: model.layers.31.mlp.fc2.bias shape: [2560] +360: model.layers.31.mlp.fc2.weight shape: [2560, 10240] +361: model.layers.31.self_attn.dense.bias shape: [2560] +362: model.layers.31.self_attn.dense.weight shape: [2560, 2560] +363: model.layers.31.self_attn.k_proj.bias shape: [2560] +364: model.layers.31.self_attn.k_proj.weight shape: [2560, 2560] +365: model.layers.31.self_attn.q_proj.bias shape: [2560] +366: model.layers.31.self_attn.q_proj.weight shape: [2560, 2560] +367: model.layers.31.self_attn.v_proj.bias shape: [2560] +368: model.layers.31.self_attn.v_proj.weight shape: [2560, 2560] +369: model.layers.4.input_layernorm.bias shape: [2560] +370: model.layers.4.input_layernorm.weight shape: [2560] +371: model.layers.4.mlp.fc1.bias shape: [10240] +372: model.layers.4.mlp.fc1.weight shape: [10240, 2560] +373: model.layers.4.mlp.fc2.bias shape: [2560] +374: model.layers.4.mlp.fc2.weight shape: [2560, 10240] +375: model.layers.4.self_attn.dense.bias shape: [2560] +376: model.layers.4.self_attn.dense.weight shape: [2560, 2560] +377: model.layers.4.self_attn.k_proj.bias shape: [2560] +378: model.layers.4.self_attn.k_proj.weight shape: [2560, 2560] +379: model.layers.4.self_attn.q_proj.bias shape: [2560] +380: model.layers.4.self_attn.q_proj.weight shape: [2560, 2560] +381: model.layers.4.self_attn.v_proj.bias shape: [2560] +382: model.layers.4.self_attn.v_proj.weight shape: [2560, 2560] +383: model.layers.5.input_layernorm.bias shape: [2560] +384: model.layers.5.input_layernorm.weight shape: [2560] +385: model.layers.5.mlp.fc1.bias shape: [10240] +386: model.layers.5.mlp.fc1.weight shape: [10240, 2560] +387: model.layers.5.mlp.fc2.bias shape: [2560] +388: model.layers.5.mlp.fc2.weight shape: [2560, 10240] +389: model.layers.5.self_attn.dense.bias shape: [2560] +390: model.layers.5.self_attn.dense.weight shape: [2560, 2560] +391: model.layers.5.self_attn.k_proj.bias shape: [2560] +392: model.layers.5.self_attn.k_proj.weight shape: [2560, 2560] +393: model.layers.5.self_attn.q_proj.bias shape: [2560] +394: model.layers.5.self_attn.q_proj.weight shape: [2560, 2560] +395: model.layers.5.self_attn.v_proj.bias shape: [2560] +396: model.layers.5.self_attn.v_proj.weight shape: [2560, 2560] +397: model.layers.6.input_layernorm.bias shape: [2560] +398: model.layers.6.input_layernorm.weight shape: [2560] +399: model.layers.6.mlp.fc1.bias shape: [10240] +400: model.layers.6.mlp.fc1.weight shape: [10240, 2560] +401: model.layers.6.mlp.fc2.bias shape: [2560] +402: model.layers.6.mlp.fc2.weight shape: [2560, 10240] +403: model.layers.6.self_attn.dense.bias shape: [2560] +404: model.layers.6.self_attn.dense.weight shape: [2560, 2560] +405: model.layers.6.self_attn.k_proj.bias shape: [2560] +406: model.layers.6.self_attn.k_proj.weight shape: [2560, 2560] +407: model.layers.6.self_attn.q_proj.bias shape: [2560] +408: model.layers.6.self_attn.q_proj.weight shape: [2560, 2560] +409: model.layers.6.self_attn.v_proj.bias shape: [2560] +410: model.layers.6.self_attn.v_proj.weight shape: [2560, 2560] +411: model.layers.7.input_layernorm.bias shape: [2560] +412: model.layers.7.input_layernorm.weight shape: [2560] +413: model.layers.7.mlp.fc1.bias shape: [10240] +414: model.layers.7.mlp.fc1.weight shape: [10240, 2560] +415: model.layers.7.mlp.fc2.bias shape: [2560] +416: model.layers.7.mlp.fc2.weight shape: [2560, 10240] +417: model.layers.7.self_attn.dense.bias shape: [2560] +418: model.layers.7.self_attn.dense.weight shape: [2560, 2560] +419: model.layers.7.self_attn.k_proj.bias shape: [2560] +420: model.layers.7.self_attn.k_proj.weight shape: [2560, 2560] +421: model.layers.7.self_attn.q_proj.bias shape: [2560] +422: model.layers.7.self_attn.q_proj.weight shape: [2560, 2560] +423: model.layers.7.self_attn.v_proj.bias shape: [2560] +424: model.layers.7.self_attn.v_proj.weight shape: [2560, 2560] +425: model.layers.8.input_layernorm.bias shape: [2560] +426: model.layers.8.input_layernorm.weight shape: [2560] +427: model.layers.8.mlp.fc1.bias shape: [10240] +428: model.layers.8.mlp.fc1.weight shape: [10240, 2560] +429: model.layers.8.mlp.fc2.bias shape: [2560] +430: model.layers.8.mlp.fc2.weight shape: [2560, 10240] +431: model.layers.8.self_attn.dense.bias shape: [2560] +432: model.layers.8.self_attn.dense.weight shape: [2560, 2560] +433: model.layers.8.self_attn.k_proj.bias shape: [2560] +434: model.layers.8.self_attn.k_proj.weight shape: [2560, 2560] +435: model.layers.8.self_attn.q_proj.bias shape: [2560] +436: model.layers.8.self_attn.q_proj.weight shape: [2560, 2560] +437: model.layers.8.self_attn.v_proj.bias shape: [2560] +438: model.layers.8.self_attn.v_proj.weight shape: [2560, 2560] +439: model.layers.9.input_layernorm.bias shape: [2560] +440: model.layers.9.input_layernorm.weight shape: [2560] +441: model.layers.9.mlp.fc1.bias shape: [10240] +442: model.layers.9.mlp.fc1.weight shape: [10240, 2560] +443: model.layers.9.mlp.fc2.bias shape: [2560] +444: model.layers.9.mlp.fc2.weight shape: [2560, 10240] +445: model.layers.9.self_attn.dense.bias shape: [2560] +446: model.layers.9.self_attn.dense.weight shape: [2560, 2560] +447: model.layers.9.self_attn.k_proj.bias shape: [2560] +448: model.layers.9.self_attn.k_proj.weight shape: [2560, 2560] +449: model.layers.9.self_attn.q_proj.bias shape: [2560] +450: model.layers.9.self_attn.q_proj.weight shape: [2560, 2560] +451: model.layers.9.self_attn.v_proj.bias shape: [2560] +452: model.layers.9.self_attn.v_proj.weight shape: [2560, 2560] diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi2Tests.TokenizerTest.approved.txt b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi2Tests.TokenizerTest.approved.txt new file mode 100644 index 0000000000..7338548917 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Approvals/Phi2Tests.TokenizerTest.approved.txt @@ -0,0 +1,3 @@ +50256, 6090, 345, 2148, 2842, 284, 4483, 17790, 286, 35484, 290, 10441, 69, 50187, 30 +50256, 19457, 0, 3423, 389, 617, 2842, 284, 4483, 35484, 290, 10441, 69, 50187, 1978, 25, 352, 13, 40058, 290, 10441, 34711, 7209, 494, 25, 41198, 35484, 290, 10441, 69, 50187, 1978, 351, 617, 7545, 290, 12498, 13, 362, 13, 40058, 290, 10441, 34711, 20698, 25, 15561, 26790, 35484, 290, 10441, 69, 50187, 1978, 351, 617, 18873, 13135, 290, 12498, 13 +50256, 2061, 546, 18120, 281, 362, 87, 1343, 513, 796, 767, 16022, 30 diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj index c520ba7db3..ff13dfe880 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj @@ -1,7 +1,7 @@  - net6.0 + net6.0;net8.0 enable $(NoWarn);MSML_ExtendBaseTestClass enable diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs index 2a938de68d..bbcb0fa850 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs @@ -12,6 +12,7 @@ using Xunit; namespace Microsoft.ML.GenAI.Phi.Tests; +[Collection("NoParallelization")] public class Phi2Tests { [Fact] diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs index b65d0bf995..4c7a395e61 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs @@ -13,6 +13,7 @@ namespace Microsoft.ML.GenAI.Phi.Tests; +[Collection("NoParallelization")] public class Phi3Tests { [Fact] From 1a77f8d15fc55d44fe4f8fc8165671a0c8b08bf0 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 17 Jul 2024 15:21:09 -0700 Subject: [PATCH 25/41] run test in net6.0 --- .../Microsoft.ML.GenAI.Core.Tests.csproj | 2 +- .../Microsoft.ML.GenAI.Phi.Tests.csproj | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj index 612fb7df5b..aa1c14c620 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj @@ -1,7 +1,7 @@  - net6.0;net8.0 + net6.0 enable $(NoWarn);MSML_ExtendBaseTestClass enable diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj index ff13dfe880..c520ba7db3 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj @@ -1,7 +1,7 @@  - net6.0;net8.0 + net6.0 enable $(NoWarn);MSML_ExtendBaseTestClass enable From e3e09e429033d5250afb3be90e5ee2401f398ab9 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 17 Jul 2024 16:32:52 -0700 Subject: [PATCH 26/41] use meta device --- test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs | 6 ++++++ test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs index bbcb0fa850..b17a74c52e 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs @@ -9,12 +9,18 @@ using FluentAssertions; using Microsoft.ML.GenAI.Core.Extension; using Microsoft.ML.Tokenizers; +using TorchSharp; using Xunit; namespace Microsoft.ML.GenAI.Phi.Tests; [Collection("NoParallelization")] public class Phi2Tests { + public Phi2Tests() + { + torch.set_default_device("meta"); + } + [Fact] [UseReporter(typeof(DiffReporter))] [UseApprovalSubdirectory("Approvals")] diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs index 4c7a395e61..7f1b5e5dec 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs @@ -9,6 +9,7 @@ using ApprovalTests.Reporters; using FluentAssertions; using Microsoft.ML.GenAI.Core.Extension; +using TorchSharp; using Xunit; namespace Microsoft.ML.GenAI.Phi.Tests; @@ -16,6 +17,11 @@ namespace Microsoft.ML.GenAI.Phi.Tests; [Collection("NoParallelization")] public class Phi3Tests { + public Phi3Tests() + { + torch.set_default_device("meta"); + } + [Fact] [UseReporter(typeof(DiffReporter))] [UseApprovalSubdirectory("Approvals")] From b316f193b4e39fdec127a87440994b4d91b0d1d5 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 17 Jul 2024 19:10:52 -0700 Subject: [PATCH 27/41] copy approval tests to output folder --- .../Microsoft.ML.GenAI.Phi.Tests.csproj | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj index c520ba7db3..479514a935 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj @@ -21,6 +21,12 @@ + + + PreserveNewest + + + From d6f0e617243ad707caf65dc0c72a5980003d2bb3 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 19 Jul 2024 10:57:55 -0700 Subject: [PATCH 28/41] set up approval test file location --- test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs | 4 ++++ test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs index b17a74c52e..33402e73a0 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi2Tests.cs @@ -18,6 +18,10 @@ public class Phi2Tests { public Phi2Tests() { + if (Environment.GetEnvironmentVariable("HELIX_CORRELATION_ID") != null) + { + Approvals.UseAssemblyLocationForApprovedFiles(); + } torch.set_default_device("meta"); } diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs index 7f1b5e5dec..1200d79f9d 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs @@ -19,6 +19,11 @@ public class Phi3Tests { public Phi3Tests() { + if (Environment.GetEnvironmentVariable("HELIX_CORRELATION_ID") != null) + { + Approvals.UseAssemblyLocationForApprovedFiles(); + } + torch.set_default_device("meta"); } From 0bb6b989b49b1d2e86967e621804447b8293dfd4 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 23 Jul 2024 10:01:43 -0700 Subject: [PATCH 29/41] fix comment --- src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs | 1 + .../Phi2/{Phi2Tokenzier.cs => Phi2TokenizerHelper.cs} | 0 .../Phi3/{Phi3Tokenzier.cs => Phi3TokenizerHelper.cs} | 7 ------- 3 files changed, 1 insertion(+), 7 deletions(-) rename src/Microsoft.ML.GenAI.Phi/Phi2/{Phi2Tokenzier.cs => Phi2TokenizerHelper.cs} (100%) rename src/Microsoft.ML.GenAI.Phi/Phi3/{Phi3Tokenzier.cs => Phi3TokenizerHelper.cs} (86%) diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs index 1a4d443ccb..8c90a94cf7 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs @@ -249,6 +249,7 @@ public virtual IEnumerable GenerateStreaming( var tokens = this.Tokenizer.EncodeToTokens(x, out var _, false, false); return tokens + // Skip the first _ token automatically added by tokenizer .Where(t => t.Offset != (0, 0)) .Select(t => t.Id) .ToArray(); diff --git a/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2Tokenzier.cs b/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2TokenizerHelper.cs similarity index 100% rename from src/Microsoft.ML.GenAI.Phi/Phi2/Phi2Tokenzier.cs rename to src/Microsoft.ML.GenAI.Phi/Phi2/Phi2TokenizerHelper.cs diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3TokenizerHelper.cs similarity index 86% rename from src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs rename to src/Microsoft.ML.GenAI.Phi/Phi3/Phi3TokenizerHelper.cs index 01a0af3061..dd54378892 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3Tokenzier.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3TokenizerHelper.cs @@ -19,13 +19,6 @@ public class Phi3TokenizerHelper private const int UserSymbolId = 32010; private const int AssistantSymbolId = 32001; private const int EndSymbolId = 32007; - private readonly Dictionary _specialTokenMap = new Dictionary - { - { SystemSymbol, SystemSymbolId }, - { UserSymbol, UserSymbolId }, - { AssistantSymbol, AssistantSymbolId }, - { EndSymbol, EndSymbolId } - }; public static LlamaTokenizer FromPretrained( string modelPath, From 405d162d38ee7cf00f7aa517bc717f0f61ec4a78 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 23 Jul 2024 10:04:41 -0700 Subject: [PATCH 30/41] rename to AddGenAITextGeneration and AddGenAIChatCompletion --- .../Phi3Mini/SemanticKernelSample.cs | 4 ++-- .../Extension/SemanticKernelExtension.cs | 4 ++-- src/Microsoft.ML.GenAI.Phi/Phi2/Phi2Extension.cs | 15 --------------- .../SemanticKernelTests.cs | 4 ++-- 4 files changed, 6 insertions(+), 21 deletions(-) delete mode 100644 src/Microsoft.ML.GenAI.Phi/Phi2/Phi2Extension.cs diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs index d82bd20704..a6f445b643 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs @@ -24,7 +24,7 @@ public static async Task RunChatCompletionSample() var kernel = Kernel.CreateBuilder() - .AddPhi3AsChatCompletion(pipeline) + .AddGenAIChatCompletion(pipeline) .Build(); var chatService = kernel.GetRequiredService(); var chatHistory = new ChatHistory(); @@ -53,7 +53,7 @@ public static async Task RunTextGenerationSample() var kernel = Kernel.CreateBuilder() - .AddPhi3AsTextGeneration(pipeline) + .AddGenAITextGeneration(pipeline) .Build(); var response = await kernel.InvokePromptAsync("Tell a joke"); diff --git a/src/Microsoft.ML.GenAI.Phi/Extension/SemanticKernelExtension.cs b/src/Microsoft.ML.GenAI.Phi/Extension/SemanticKernelExtension.cs index c2ef497d64..ace7a7b425 100644 --- a/src/Microsoft.ML.GenAI.Phi/Extension/SemanticKernelExtension.cs +++ b/src/Microsoft.ML.GenAI.Phi/Extension/SemanticKernelExtension.cs @@ -13,7 +13,7 @@ namespace Microsoft.ML.GenAI.Phi.Extension; public static class SemanticKernelExtension { - public static IKernelBuilder AddPhi3AsChatCompletion( + public static IKernelBuilder AddGenAIChatCompletion( this IKernelBuilder builder, ICausalLMPipeline pipeline) { @@ -22,7 +22,7 @@ public static IKernelBuilder AddPhi3AsChatCompletion( return builder; } - public static IKernelBuilder AddPhi3AsTextGeneration( + public static IKernelBuilder AddGenAITextGeneration( this IKernelBuilder builder, ICausalLMPipeline pipeline) { diff --git a/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2Extension.cs b/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2Extension.cs deleted file mode 100644 index 9c3043a03b..0000000000 --- a/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2Extension.cs +++ /dev/null @@ -1,15 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; - -namespace Microsoft.ML.GenAI.Phi.Extension; - -public static class Phi2Extension -{ -} diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/SemanticKernelTests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/SemanticKernelTests.cs index 3bed3b5a79..98359a8722 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/SemanticKernelTests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/SemanticKernelTests.cs @@ -35,7 +35,7 @@ public async Task ItAddPhi3CausalLMChatCompletionServiceTestAsync() .Returns((string prompt, int maxLen, float temperature, float topP, string[] stopSequences) => "hello"); var kernel = Kernel.CreateBuilder() - .AddPhi3AsChatCompletion(pipeline) + .AddGenAIChatCompletion(pipeline) .Build(); var chatService = kernel.Services.GetRequiredService(); @@ -71,7 +71,7 @@ public async Task ItAddPhi3CausalLMTextGenerationServiceTestAsync() .Returns((string prompt, int maxLen, float temperature, float topP, string[] stopSequences) => "hello"); var kernel = Kernel.CreateBuilder() - .AddPhi3AsTextGeneration(pipeline) + .AddGenAITextGeneration(pipeline) .Build(); var response = await kernel.InvokePromptAsync("test"); From a1b03690291da28509de5f8655bbc04271140759 Mon Sep 17 00:00:00 2001 From: Xiaoyun Zhang Date: Fri, 26 Jul 2024 08:42:20 -0700 Subject: [PATCH 31/41] Update job-template.yml --- build/ci/job-template.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/ci/job-template.yml b/build/ci/job-template.yml index 8a78ea548b..3270a42615 100644 --- a/build/ci/job-template.yml +++ b/build/ci/job-template.yml @@ -121,7 +121,7 @@ jobs: - ${{ if eq(parameters.nightlyBuild, 'false') }}: - ${{ if eq(parameters.innerLoop, 'false') }}: - ${{ if and(eq(parameters.runSpecific, 'false'), eq(parameters.useVSTestTask, 'false')) }}: - - script: set PATH=%PATH%;%USERPROFILE%\.nuget\packages\libtorch-cpu-win-x64\2.2.0.1\runtimes\win-x64\native;%USERPROFILE%\.nuget\packages\torchsharp\0.102.5\runtimes\win-x64\native & ${{ parameters.buildScript }} /p:Build=false -configuration $(_configuration) /p:TargetArchitecture=${{ parameters.architecture }} /p:TestArchitectures=${{ parameters.architecture }} -test -integrationTest /p:Coverage=${{ parameters.codeCoverage }} $(testTargetFramework) + - script: set PATH=%PATH%;%USERPROFILE%\.nuget\packages\libtorch-cpu-win-x64\2.2.1.1\runtimes\win-x64\native;%USERPROFILE%\.nuget\packages\torchsharp\0.102.7\runtimes\win-x64\native & ${{ parameters.buildScript }} /p:Build=false -configuration $(_configuration) /p:TargetArchitecture=${{ parameters.architecture }} /p:TestArchitectures=${{ parameters.architecture }} -test -integrationTest /p:Coverage=${{ parameters.codeCoverage }} $(testTargetFramework) displayName: Run All Tests. - ${{ if and(eq(parameters.runSpecific, 'true'), eq(parameters.useVSTestTask, 'false')) }}: - script: ${{ parameters.buildScript }} /p:Build=false -configuration $(_configuration) /p:TargetArchitecture=${{ parameters.architecture }} /p:TestArchitectures=${{ parameters.architecture }} -test -integrationTest /p:TestRunnerAdditionalArguments='-trait$(spaceValue)Category=RunSpecificTest' /p:Coverage=${{ parameters.codeCoverage }} $(testTargetFramework) From 6b3b46fbf7142dd316976e10a677e2917d49ba29 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Mon, 29 Jul 2024 10:49:54 -0700 Subject: [PATCH 32/41] add mit license --- THIRD-PARTY-NOTICES.TXT | 26 +++++++++++++++++++ .../Microsoft.ML.GenAI.Core.Tests.csproj | 6 ++--- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/THIRD-PARTY-NOTICES.TXT b/THIRD-PARTY-NOTICES.TXT index 650a7bd53d..0a86247f4e 100644 --- a/THIRD-PARTY-NOTICES.TXT +++ b/THIRD-PARTY-NOTICES.TXT @@ -171,3 +171,29 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +License notice for Torchsharp.PyBridge +------------------------------------------ +https://github.com/shaltielshmid/TorchSharp.PyBridge/blob/main/LICENSE + +MIT License + +Copyright (c) 2023 shaltielshmid + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj index aa1c14c620..112d3a7b29 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj @@ -22,8 +22,8 @@ - - - + + + From c165b5fdd42a836313b76be96efdb6c3a68da4a7 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Mon, 29 Jul 2024 11:07:52 -0700 Subject: [PATCH 33/41] add reference --- .../Microsoft.ML.GenAI.Core.Tests.csproj | 1 + .../Microsoft.ML.GenAI.Phi.Tests.csproj | 1 + 2 files changed, 2 insertions(+) diff --git a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj index 112d3a7b29..fff3befa01 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj @@ -15,6 +15,7 @@ + diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj index 479514a935..b19b6ec7a7 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj @@ -19,6 +19,7 @@ + From b837269ddfed0d84e5d89965bf032092ad7ab427 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Mon, 29 Jul 2024 14:08:51 -0700 Subject: [PATCH 34/41] bump code coverage version --- eng/Versions.props | 4 ++-- src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/eng/Versions.props b/eng/Versions.props index 6dca3cb3bd..c7636793dd 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -70,8 +70,8 @@ 2.2.1.1 1.12.4 - 3.1.2 - 3.1.2 + 6.0.2 + 6.0.0 3.3.1 4.5.0 4.3.6 diff --git a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj index 0c51883e6c..af8b6aed6e 100644 --- a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj +++ b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj @@ -9,12 +9,13 @@ + + - From 3020ebd870faaed8dc64b1aa4837e3747e311ea6 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Mon, 29 Jul 2024 17:01:51 -0700 Subject: [PATCH 35/41] add true --- .../Microsoft.ML.GenAI.Core.Tests.csproj | 1 + .../Microsoft.ML.GenAI.Phi.Tests.csproj | 1 + 2 files changed, 2 insertions(+) diff --git a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj index fff3befa01..9960611dcf 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj @@ -5,6 +5,7 @@ enable $(NoWarn);MSML_ExtendBaseTestClass enable + true diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj index b19b6ec7a7..dbe744ab12 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj @@ -5,6 +5,7 @@ enable $(NoWarn);MSML_ExtendBaseTestClass enable + true From 077d36675f5845ad805d47063c07e4ef7c41004e Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Mon, 29 Jul 2024 17:08:34 -0700 Subject: [PATCH 36/41] add runtime package --- src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj | 3 +++ src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj index 290b2da925..da47fd8767 100644 --- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj +++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj @@ -10,6 +10,9 @@ + + + diff --git a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj index af8b6aed6e..dd783d70d5 100644 --- a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj +++ b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj @@ -11,6 +11,9 @@ + + + From cd3b20ef4036a1567de428abdecef0b9e3c52484 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 30 Jul 2024 00:26:50 -0700 Subject: [PATCH 37/41] remove flag --- .../Microsoft.ML.GenAI.Core.Tests.csproj | 1 - .../Microsoft.ML.GenAI.Phi.Tests.csproj | 1 - 2 files changed, 2 deletions(-) diff --git a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj index 9960611dcf..fff3befa01 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj @@ -5,7 +5,6 @@ enable $(NoWarn);MSML_ExtendBaseTestClass enable - true diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj index dbe744ab12..b19b6ec7a7 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj @@ -5,7 +5,6 @@ enable $(NoWarn);MSML_ExtendBaseTestClass enable - true From 62576c4aa73bb69545a0853244e5f26ffc6681ca Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 30 Jul 2024 00:28:16 -0700 Subject: [PATCH 38/41] add flag --- src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj | 5 +++++ src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj index da47fd8767..d6acb3dd81 100644 --- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj +++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj @@ -10,6 +10,11 @@ + + + + + diff --git a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj index dd783d70d5..d40f6f9c70 100644 --- a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj +++ b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj @@ -11,6 +11,11 @@ + + + + + From d91ec0a0cd85520304fa28c99c191e786e9aee86 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 30 Jul 2024 09:25:19 -0700 Subject: [PATCH 39/41] fix build error --- src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj | 2 +- src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj index d6acb3dd81..6daadff172 100644 --- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj +++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj @@ -13,7 +13,7 @@ - + diff --git a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj index d40f6f9c70..2d0ec3b4a7 100644 --- a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj +++ b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj @@ -14,7 +14,7 @@ - + From e7d0fdeed41074d6c69214a3c0de0deae2ffa77e Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 30 Jul 2024 10:58:28 -0700 Subject: [PATCH 40/41] update --- .../Microsoft.ML.GenAI.Core.Tests.csproj | 1 + .../Microsoft.ML.GenAI.Phi.Tests.csproj | 1 + 2 files changed, 2 insertions(+) diff --git a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj index fff3befa01..9960611dcf 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj @@ -5,6 +5,7 @@ enable $(NoWarn);MSML_ExtendBaseTestClass enable + true diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj index b19b6ec7a7..dbe744ab12 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj @@ -5,6 +5,7 @@ enable $(NoWarn);MSML_ExtendBaseTestClass enable + true From a97afdfb2ece7921c1d498d8e621bfe5ce25d94b Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 30 Jul 2024 12:33:51 -0700 Subject: [PATCH 41/41] update --- src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj | 5 ++--- src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj index 6daadff172..dfb64082fb 100644 --- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj +++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj @@ -11,14 +11,13 @@ - + - + --> diff --git a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj index 2d0ec3b4a7..a9556443dd 100644 --- a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj +++ b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj @@ -14,12 +14,11 @@ - - +