From 6e005eb5bf9c965afc7f582e7d8f8952db6b57da Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 3 Sep 2024 11:17:23 -0700 Subject: [PATCH 01/17] add mistral and tests --- Microsoft.ML.sln | 24 ++++- .../Microsoft.ML.GenAI.Core.csproj | 4 +- .../Microsoft.ML.GenAI.Mistral.csproj | 24 +++++ .../MistralConfig.cs | 88 +++++++++++++++++ src/Microsoft.ML.GenAI.Mistral/MistralMLP.cs | 45 +++++++++ .../MistralTokenizerHelper.cs | 99 +++++++++++++++++++ ...stral_V0_3Tests.TokenizerTest.approved.txt | 2 + .../Microsoft.ML.GenAI.Mistral.Tests.csproj | 44 +++++++++ .../Mistral_V0_3Tests.cs | 48 +++++++++ 9 files changed, 376 insertions(+), 2 deletions(-) create mode 100644 src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj create mode 100644 src/Microsoft.ML.GenAI.Mistral/MistralConfig.cs create mode 100644 src/Microsoft.ML.GenAI.Mistral/MistralMLP.cs create mode 100644 src/Microsoft.ML.GenAI.Mistral/MistralTokenizerHelper.cs create mode 100644 test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_V0_3Tests.TokenizerTest.approved.txt create mode 100644 test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj create mode 100644 test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_V0_3Tests.cs diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index c55f5797f2..00635886a1 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -188,7 +188,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.Core.Tes EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.LLaMA", "src\Microsoft.ML.GenAI.LLaMA\Microsoft.ML.GenAI.LLaMA.csproj", "{0AA6D5CB-195F-457A-8792-4221E76E6C44}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.LLaMA.Tests", "test\Microsoft.ML.GenAI.LLaMA.Tests\Microsoft.ML.GenAI.LLaMA.Tests.csproj", "{D202353D-6FAF-4263-9A01-BDCFBC92391F}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.LLaMA.Tests", "test\Microsoft.ML.GenAI.LLaMA.Tests\Microsoft.ML.GenAI.LLaMA.Tests.csproj", "{D202353D-6FAF-4263-9A01-BDCFBC92391F}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.Mistral", "src\Microsoft.ML.GenAI.Mistral\Microsoft.ML.GenAI.Mistral.csproj", "{2729CC66-7743-442B-B3A5-1F4F27F044A5}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.Mistral.Tests", "test\Microsoft.ML.GenAI.Mistral.Tests\Microsoft.ML.GenAI.Mistral.Tests.csproj", "{49264202-C90A-43F6-8C30-BDAEF2F1465A}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -898,6 +902,22 @@ Global {D202353D-6FAF-4263-9A01-BDCFBC92391F}.Release|Any CPU.Build.0 = Release|Any CPU {D202353D-6FAF-4263-9A01-BDCFBC92391F}.Release|x64.ActiveCfg = Release|Any CPU {D202353D-6FAF-4263-9A01-BDCFBC92391F}.Release|x64.Build.0 = Release|Any CPU + {2729CC66-7743-442B-B3A5-1F4F27F044A5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {2729CC66-7743-442B-B3A5-1F4F27F044A5}.Debug|Any CPU.Build.0 = Debug|Any CPU + {2729CC66-7743-442B-B3A5-1F4F27F044A5}.Debug|x64.ActiveCfg = Debug|Any CPU + {2729CC66-7743-442B-B3A5-1F4F27F044A5}.Debug|x64.Build.0 = Debug|Any CPU + {2729CC66-7743-442B-B3A5-1F4F27F044A5}.Release|Any CPU.ActiveCfg = Release|Any CPU + {2729CC66-7743-442B-B3A5-1F4F27F044A5}.Release|Any CPU.Build.0 = Release|Any CPU + {2729CC66-7743-442B-B3A5-1F4F27F044A5}.Release|x64.ActiveCfg = Release|Any CPU + {2729CC66-7743-442B-B3A5-1F4F27F044A5}.Release|x64.Build.0 = Release|Any CPU + {49264202-C90A-43F6-8C30-BDAEF2F1465A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {49264202-C90A-43F6-8C30-BDAEF2F1465A}.Debug|Any CPU.Build.0 = Debug|Any CPU + {49264202-C90A-43F6-8C30-BDAEF2F1465A}.Debug|x64.ActiveCfg = Debug|Any CPU + {49264202-C90A-43F6-8C30-BDAEF2F1465A}.Debug|x64.Build.0 = Debug|Any CPU + {49264202-C90A-43F6-8C30-BDAEF2F1465A}.Release|Any CPU.ActiveCfg = Release|Any CPU + {49264202-C90A-43F6-8C30-BDAEF2F1465A}.Release|Any CPU.Build.0 = Release|Any CPU + {49264202-C90A-43F6-8C30-BDAEF2F1465A}.Release|x64.ActiveCfg = Release|Any CPU + {49264202-C90A-43F6-8C30-BDAEF2F1465A}.Release|x64.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -991,6 +1011,8 @@ Global {14AB0804-D4CE-4634-B544-5A8587620783} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {0AA6D5CB-195F-457A-8792-4221E76E6C44} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {D202353D-6FAF-4263-9A01-BDCFBC92391F} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} + {2729CC66-7743-442B-B3A5-1F4F27F044A5} = {09EADF06-BE25-4228-AB53-95AE3E15B530} + {49264202-C90A-43F6-8C30-BDAEF2F1465A} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj index 8745b81c6d..64087de176 100644 --- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj +++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj @@ -20,9 +20,11 @@ + - + + diff --git a/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj b/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj new file mode 100644 index 0000000000..5b0cb0acc0 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj @@ -0,0 +1,24 @@ + + + + net6.0;net8.0 + enable + enable + + + + + + + + + + + + + + + + + + diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralConfig.cs b/src/Microsoft.ML.GenAI.Mistral/MistralConfig.cs new file mode 100644 index 0000000000..a1a66a5585 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Mistral/MistralConfig.cs @@ -0,0 +1,88 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Text.Json.Serialization; +using Microsoft.ML.GenAI.Core; +using TorchSharp; + +namespace Microsoft.ML.GenAI.Mistral; + +public class MistralConfig +{ + public MistralConfig() + { + this.AttentionBias = false; + this.AttentionDropout = 0.0; + this.HiddenAct = "silu"; + this.HiddenSize = 4096; + this.InitializerRange = 0.02; + this.IntermediateSize = 14336; + this.MaxPositionEmbeddings = 131072; + this.MlpBias = false; + this.NumAttentionHeads = 32; + this.NumHiddenLayers = 32; + this.NumKeyValueHeads = 8; + this.PretrainingTp = 1; + this.RmsNormEps = 1e-05f; + this.RopeScaling = new RopeScalingConfig(); + this.RopeTheta = 500000.0; + this.TieWordEmbeddings = false; + this.VocabSize = 128256; + this.AttnImplementation = "eager"; + this.DType = torch.ScalarType.BFloat16; + } + + [JsonPropertyName("attention_bias")] + public bool AttentionBias { get; set; } + + [JsonPropertyName("attention_dropout")] + public double AttentionDropout { get; set; } + + [JsonPropertyName("hidden_act")] + public string HiddenAct { get; set; } + + [JsonPropertyName("hidden_size")] + public int HiddenSize { get; set; } + + [JsonPropertyName("initializer_range")] + public double InitializerRange { get; set; } + + [JsonPropertyName("intermediate_size")] + public int IntermediateSize { get; set; } + + [JsonPropertyName("max_position_embeddings")] + public int MaxPositionEmbeddings { get; set; } + + [JsonPropertyName("mlp_bias")] + public bool MlpBias { get; set; } + + [JsonPropertyName("num_attention_heads")] + public int NumAttentionHeads { get; set; } + + [JsonPropertyName("num_hidden_layers")] + public int NumHiddenLayers { get; set; } + + [JsonPropertyName("num_key_value_heads")] + public int NumKeyValueHeads { get; set; } + + [JsonPropertyName("pretraining_tp")] + public int PretrainingTp { get; set; } + + [JsonPropertyName("rms_norm_eps")] + public float RmsNormEps { get; set; } + + public RopeScalingConfig RopeScaling { get; set; } + + [JsonPropertyName("rope_theta")] + public double RopeTheta { get; set; } + + [JsonPropertyName("tie_word_embeddings")] + public bool TieWordEmbeddings { get; set; } + + [JsonPropertyName("vocab_size")] + public int VocabSize { get; set; } + public int? PadTokenId { get; set; } + public torch.ScalarType DType { get; set; } + public string AttnImplementation { get; set; } +} diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralMLP.cs b/src/Microsoft.ML.GenAI.Mistral/MistralMLP.cs new file mode 100644 index 0000000000..347ee625ee --- /dev/null +++ b/src/Microsoft.ML.GenAI.Mistral/MistralMLP.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.GenAI.Core; +using TorchSharp; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Mistral.Module; +#pragma warning disable MSML_GeneralName // This name should be PascalCased +internal class MistralMLP : torch.nn.Module +#pragma warning restore MSML_GeneralName // This name should be PascalCased +{ + private readonly int _intermediateSize; + private readonly int _hiddenSize; +#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly QuantizedLinear gate_proj; + private readonly QuantizedLinear up_proj; + private readonly QuantizedLinear down_proj; + private readonly torch.nn.Module act_fn; +#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format + + public MistralMLP(MistralConfig config) + : base(nameof(MistralMLP)) + { + this._hiddenSize = config.HiddenSize; + this._intermediateSize = config.IntermediateSize; + var hiddenAct = config.HiddenAct; + this.gate_proj = new QuantizedLinear(this._hiddenSize, this._intermediateSize, hasBias: false, dtype: config.DType); + this.up_proj = new QuantizedLinear(this._hiddenSize, this._intermediateSize, hasBias: false, dtype: config.DType); + this.down_proj = new QuantizedLinear(this._intermediateSize, this._hiddenSize, hasBias: false, dtype: config.DType); + this.RegisterComponents(); + this.act_fn = Core.Utils.GetActivation(hiddenAct); + } + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override Tensor forward(Tensor input) +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + using var input1 = this.gate_proj.forward(input); + using var input2 = this.act_fn.forward(input1); + using var input3 = input2 * this.up_proj.forward(input); + return this.down_proj.forward(input3); + } +} diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralTokenizerHelper.cs b/src/Microsoft.ML.GenAI.Mistral/MistralTokenizerHelper.cs new file mode 100644 index 0000000000..4da9091030 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Mistral/MistralTokenizerHelper.cs @@ -0,0 +1,99 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.ML.Tokenizers; + +namespace Microsoft.ML.GenAI.Mistral; + +public class MistralTokenizerHelper +{ + private const string UnknownSymbol = ""; + private const int UnknownSymbolId = 0; + private const string StartSymbol = ""; + private const int StartSymbolId = 1; + private const string EndSymbol = ""; + private const int EndSymbolId = 2; + private const string StartInstructionSymbol = "[INST]"; + private const int StartInstructionSymbolId = 3; + private const string EndInstructionSymbol = "[/INST]"; + private const int EndInstructionSymbolId = 4; + private const string ToolCallSymbol = "[TOOL_CALLS]"; + private const int ToolCallSymbolId = 5; + private const string StartAvailableToolsSymbol = "[AVAILABLE_TOOLS]"; + private const int StartAvailableToolsSymbolId = 6; + private const string EndAvailableToolsSymbol = "[/AVAILABLE_TOOLS]"; + private const int EndAvailableToolsSymbolId = 7; + private const string StartToolResultSymbol = "[TOOL_RESULTS]"; + private const int StartToolResultSymbolId = 8; + private const string EndToolResultSymbol = "[/TOOL_RESULTS]"; + private const int EndToolResultSymbolId = 9; + + public static LlamaTokenizer FromPretrained( + string modelWeightFolder, + string modelName = "tokenizer.model.v3", + string unknownSymbol = UnknownSymbol, + int unknownSymbolId = 0, + string startSymbol = StartSymbol, + int startSymbolId = 1, + string endSymbol = EndSymbol, + int endSymbolId = 2, + string startInstructionSymbol = StartInstructionSymbol, + int startInstructionSymbolId = 3, + string endInstructionSymbol = EndInstructionSymbol, + int endInstructionSymbolId = 4, + string toolCallSymbol = ToolCallSymbol, + int toolCallSymbolId = 5, + string startAvailableToolsSymbol = StartAvailableToolsSymbol, + int startAvailableToolsSymbolId = 6, + string endAvailableToolsSymbol = EndAvailableToolsSymbol, + int endAvailableToolsSymbolId = 7, + string startToolResultSymbol = StartToolResultSymbol, + int startToolResultSymbolId = 8, + string endToolResultSymbol = EndToolResultSymbol, + int endToolResultSymbolId = 9, + bool addPrecedingSpace = true) + { + var specialTokens = new Dictionary + { + { unknownSymbol, unknownSymbolId }, + { startSymbol, startSymbolId }, + { endSymbol, endSymbolId }, + { startInstructionSymbol, startInstructionSymbolId }, + { endInstructionSymbol, endInstructionSymbolId }, + { toolCallSymbol, toolCallSymbolId }, + { startAvailableToolsSymbol, startAvailableToolsSymbolId }, + { endAvailableToolsSymbol, endAvailableToolsSymbolId }, + { startToolResultSymbol, startToolResultSymbolId }, + { endToolResultSymbol, endToolResultSymbolId } + }; + + return FromPretrained( + modelWeightFolder, + modelName, + specialTokens, + addPrecedingSpace); + } + + public static LlamaTokenizer FromPretrained( + string modelWeightFolder, + string modelName, + Dictionary specialTokens, + bool addPrecedingSpace = true) + { + var modelPath = Path.Combine(modelWeightFolder, modelName); + var modelStream = File.OpenRead(modelPath); + + var llamaTokenizer = LlamaTokenizer.Create( + modelStream, + addPrecedingSpace, + specialTokens: specialTokens); + + return llamaTokenizer; + } +} diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_V0_3Tests.TokenizerTest.approved.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_V0_3Tests.TokenizerTest.approved.txt new file mode 100644 index 0000000000..fc8562c9e9 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_V0_3Tests.TokenizerTest.approved.txt @@ -0,0 +1,2 @@ + [{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "format": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location."}}, "required": ["location", "format"]}}}] What's the weather like in Paris? [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}, "id": "9Ae3bDc2F"}] {"content": 22.0, "call_id": "9Ae3bDc2F"} The current temperature in Paris is 22.0 degrees celsius. +1, 1, 6, 1501, 7567, 1891, 2032, 1113, 3396, 1316, 1113, 3396, 2032, 10598, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 1537, 1991, 1316, 1113, 7286, 2032, 1113, 2226, 1040, 2636, 8854, 1316, 1113, 12206, 2032, 10598, 1891, 2032, 1113, 3582, 1316, 1113, 11491, 2032, 10598, 3501, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 7286, 2032, 1113, 1782, 3758, 1072, 2433, 29493, 1085, 29491, 29489, 29491, 4420, 10454, 29493, 10229, 8474, 1113, 4530, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 10825, 2032, 8135, 29485, 1958, 3938, 1316, 1113, 29490, 19425, 13075, 9651, 1113, 7286, 2032, 1113, 1782, 8409, 5796, 1066, 1706, 29491, 1328, 1410, 1224, 1245, 1040, 6211, 5491, 1379, 11549, 1113, 11661, 2032, 8135, 3501, 1316, 1113, 4530, 3010, 14879, 29561, 7, 3, 2592, 29510, 29481, 1040, 8854, 1505, 1065, 6233, 29572, 4, 5, 1501, 7567, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 29475, 17329, 1316, 1113, 17452, 2032, 10598, 3501, 2032, 1113, 4684, 1046, 29493, 5611, 1316, 1113, 6074, 2032, 1113, 29485, 1958, 3938, 8474, 1113, 1081, 2032, 1113, 29542, 29509, 29474, 29538, 29494, 29525, 29485, 29518, 29533, 29507, 10925, 2, 8, 10598, 4557, 2032, 29473, 29518, 29518, 29491, 29502, 29493, 1113, 3613, 29498, 1081, 2032, 1113, 29542, 29509, 29474, 29538, 29494, 29525, 29485, 29518, 29533, 18163, 9, 1183, 2636, 8409, 1065, 6233, 1117, 29473, 29518, 29518, 29491, 29502, 11950, 1045, 1958, 3938, 29491, 2 diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj b/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj new file mode 100644 index 0000000000..da2dde0bfc --- /dev/null +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj @@ -0,0 +1,44 @@ + + + + net6.0 + enable + $(NoWarn);MSML_ExtendBaseTestClass + enable + true + + + + + + + + + + + + + + + + + + + + + + + + + PreserveNewest + + + + + + + + + + + diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_V0_3Tests.cs b/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_V0_3Tests.cs new file mode 100644 index 0000000000..63017258bf --- /dev/null +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_V0_3Tests.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using ApprovalTests.Namers; +using ApprovalTests.Reporters; +using ApprovalTests; +using Xunit; + +namespace Microsoft.ML.GenAI.Mistral.Tests; + +public class Mistral_V0_3Tests +{ + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("Approvals")] + public void TokenizerTest() + { + var modelWeightFolder = Path.Join("C:\\Users\\xiaoyuz\\source\\repos\\Mistral-7B-Instruct-v0.3"); + var tokenizer = MistralTokenizerHelper.FromPretrained(modelWeightFolder); + + var messages = new string[] + { + // system : You are a helpful assistant that can answer questions about the weather. + // tool: [get-weather-tool-call] + // user : What's the weather like in Paris? + // assistant: // get-weather-tool-call + // tool: get-weather-tool-call-result + // assistant: The current temperature in Paris is 22.0 degrees celsius. + """ + [AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "format": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location."}}, "required": ["location", "format"]}}}][/AVAILABLE_TOOLS][INST] What's the weather like in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}, "id": "9Ae3bDc2F"}][TOOL_RESULTS] {"content": 22.0, "call_id": "9Ae3bDc2F"}[/TOOL_RESULTS] The current temperature in Paris is 22.0 degrees celsius. + """ + }; + + var sb = new StringBuilder(); + foreach (var message in messages) + { + var tokenizeIds = tokenizer.EncodeToIds(message, true, false); + var decodeToString = tokenizer.Decode(tokenizeIds); + sb.AppendLine(decodeToString); + var tokenizedStr = string.Join(", ", tokenizeIds.Select(x => x.ToString())); + + sb.AppendLine(tokenizedStr); + } + Approvals.Verify(sb.ToString()); + } +} From 7bd933849029e5209829da0a1bd68e299056547c Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 3 Sep 2024 13:55:48 -0700 Subject: [PATCH 02/17] add test and sample --- .../Microsoft.ML.GenAI.Samples.csproj | 1 + .../Mistral/Mistral_7B_Instruct.cs | 47 +++ .../Microsoft.ML.GenAI.Samples/Program.cs | 3 +- .../Utility/IChatTemplateBuilder.cs | 2 +- src/Microsoft.ML.GenAI.Core/Utils.cs | 15 + .../Llama3_1ChatTemplateBuilder.cs | 2 +- .../Module/LlamaModel.cs | 6 - src/Microsoft.ML.GenAI.LLaMA/Utils.cs | 27 -- .../MistralCausalLMAgent.cs | 89 ++++++ .../MistralConfig.cs | 30 +- .../MistralDecoderLayer.cs | 148 +++++++++ .../MistralForCausalLM.cs | 121 ++++++++ .../MistralModel.cs | 148 +++++++++ .../Mistral_7B_0_3ChatTemplateBuilder.cs | 79 +++++ .../Config/mistral-7B-instruct-v0.3.json | 21 ++ src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs | 2 +- src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs | 8 +- src/Microsoft.ML.GenAI.Phi/Utils.cs | 60 ---- ...emplateFromAutoGenChatHistory.approved.txt | 1 + ...emplateFromAutoGenChatHistory.received.txt | 1 + ...al_7B_Instruct_V0_3_ShapeTest.approved.txt | 291 ++++++++++++++++++ ...truct_V0_3Tests.TokenizerTest.approved.txt | 2 + ...ts.cs => Mistral_7B_Instruct_V0_3Tests.cs} | 43 ++- 23 files changed, 1042 insertions(+), 105 deletions(-) create mode 100644 docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs delete mode 100644 src/Microsoft.ML.GenAI.LLaMA/Utils.cs create mode 100644 src/Microsoft.ML.GenAI.Mistral/MistralCausalLMAgent.cs create mode 100644 src/Microsoft.ML.GenAI.Mistral/MistralDecoderLayer.cs create mode 100644 src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs create mode 100644 src/Microsoft.ML.GenAI.Mistral/MistralModel.cs create mode 100644 src/Microsoft.ML.GenAI.Mistral/Mistral_7B_0_3ChatTemplateBuilder.cs create mode 100644 src/Microsoft.ML.GenAI.Mistral/Resource/Config/mistral-7B-instruct-v0.3.json create mode 100644 test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.approved.txt create mode 100644 test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.received.txt create mode 100644 test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.Mistral_7B_Instruct_V0_3_ShapeTest.approved.txt create mode 100644 test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.TokenizerTest.approved.txt rename test/Microsoft.ML.GenAI.Mistral.Tests/{Mistral_V0_3Tests.cs => Mistral_7B_Instruct_V0_3Tests.cs} (64%) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj index d9932106d6..fab7d39b57 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj @@ -10,6 +10,7 @@ + diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs new file mode 100644 index 0000000000..e91c608279 --- /dev/null +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs @@ -0,0 +1,47 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using AutoGen.Core; +using Microsoft.ML.GenAI.Core; +using Microsoft.ML.GenAI.Mistral; +using Microsoft.ML.Tokenizers; +using TorchSharp; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Samples.Mistral; + +internal class Mistral_7B_Instruct +{ + public static async void Run() + { + var device = "cuda"; + if (device == "cuda") + { + torch.InitializeDeviceType(DeviceType.CUDA); + } + + var defaultType = ScalarType.BFloat16; + torch.manual_seed(1); + torch.set_default_dtype(defaultType); + var weightFolder = @"C:\Users\xiaoyuz\source\repos\Mistral-7B-Instruct-v0.3"; + var configName = "config.json"; + var originalWeightFolder = Path.Combine(weightFolder); + + Console.WriteLine("Loading Mistral from huggingface model weight folder"); + var tokenizer = MistralTokenizerHelper.FromPretrained(originalWeightFolder); + var model = MistralForCausalLM.FromPretrained(weightFolder, configName, layersOnTargetDevice: -1); + + var pipeline = new CausalLMPipeline(tokenizer, model, device); + + var agent = new MistralCausalLMAgent(pipeline, "assistant") + .RegisterPrintMessage(); + + var task = """ + Write a C# program to print the sum of two numbers. Use top-level statement, put code between ```csharp and ```. + """; + + await agent.SendAsync(task); + } +} diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs index 5e4355e595..2ebfb25d38 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs @@ -1,4 +1,5 @@ // See https://aka.ms/new-console-template for more information +using Microsoft.ML.GenAI.Samples.Mistral; using Microsoft.ML.GenAI.Samples.Phi3Mini; -await AutoGenSample.RunAsync(); +Mistral_7B_Instruct.Run(); diff --git a/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs index a0720694c3..4cf5a00abf 100644 --- a/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs +++ b/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs @@ -19,7 +19,7 @@ public interface ISemanticKernelChatTemplateBuilder public interface IAutoGenChatTemplateBuilder { - string BuildPrompt(IEnumerable messages); + string BuildPrompt(IEnumerable messages, IEnumerable? tools = null); } public interface IChatTemplateBuilder : IAutoGenChatTemplateBuilder, ISemanticKernelChatTemplateBuilder diff --git a/src/Microsoft.ML.GenAI.Core/Utils.cs b/src/Microsoft.ML.GenAI.Core/Utils.cs index e4e1078d2e..dccabad653 100644 --- a/src/Microsoft.ML.GenAI.Core/Utils.cs +++ b/src/Microsoft.ML.GenAI.Core/Utils.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Reflection; using System.Text; using System.Threading.Tasks; using TorchSharp; @@ -161,4 +162,18 @@ public static Tensor RepeatKV(Tensor x, int nRep) .reshape(batchSize, nKVHeads * nRep, seqLen, headDim); } + internal static string GetEmbeddedResource(string resourceName) + { + // read file content from embedded resource + var assembly = Assembly.GetCallingAssembly(); + var resourceStream = assembly.GetManifestResourceStream(resourceName); + + if (resourceStream == null) + { + throw new ArgumentException("Resource not found", resourceName); + } + + using var reader = new System.IO.StreamReader(resourceStream); + return reader.ReadToEnd(); + } } diff --git a/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs index b96dee6dba..29e7fb1da2 100644 --- a/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs +++ b/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs @@ -15,7 +15,7 @@ public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder { private const char Newline = '\n'; - public string BuildPrompt(IEnumerable messages) + public string BuildPrompt(IEnumerable messages, IEnumerable? tools = null) { var availableRoles = new[] { Role.System, Role.User, Role.Assistant }; if (messages.Any(m => m.GetContent() is null)) diff --git a/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs b/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs index 1ba7820a9f..ec65128332 100644 --- a/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs +++ b/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs @@ -2,13 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; using Microsoft.ML.GenAI.Core; -using Microsoft.ML.GenAI.Core.Extension; using TorchSharp; using TorchSharp.Modules; using static TorchSharp.torch; diff --git a/src/Microsoft.ML.GenAI.LLaMA/Utils.cs b/src/Microsoft.ML.GenAI.LLaMA/Utils.cs deleted file mode 100644 index 622aba9fff..0000000000 --- a/src/Microsoft.ML.GenAI.LLaMA/Utils.cs +++ /dev/null @@ -1,27 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System.Reflection; -using TorchSharp; -using static TorchSharp.torch; - -namespace Microsoft.ML.GenAI.LLaMA; - -internal static class Utils -{ - public static string GetEmbeddedResource(string resourceName) - { - // read file content from embedded resource - var assembly = Assembly.GetExecutingAssembly(); - var resourceStream = assembly.GetManifestResourceStream(resourceName); - - if (resourceStream == null) - { - throw new ArgumentException("Resource not found", resourceName); - } - - using var reader = new System.IO.StreamReader(resourceStream); - return reader.ReadToEnd(); - } -} diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralCausalLMAgent.cs b/src/Microsoft.ML.GenAI.Mistral/MistralCausalLMAgent.cs new file mode 100644 index 0000000000..1f25d31b09 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Mistral/MistralCausalLMAgent.cs @@ -0,0 +1,89 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.CompilerServices; +using AutoGen.Core; +using Microsoft.ML.GenAI.Core; +using Microsoft.ML.Tokenizers; + +namespace Microsoft.ML.GenAI.Mistral; + +public class MistralCausalLMAgent : IStreamingAgent +{ + private readonly ICausalLMPipeline _pipeline; + private readonly string? _systemMessage; + private readonly IAutoGenChatTemplateBuilder _templateBuilder; + private readonly string _stopSequence = ""; + + /// + /// Create a new instance of . + /// + /// pipeline + /// agent name + /// system message. + /// the template builder to build chat prompt. If the value is null, would be used. + public MistralCausalLMAgent( + ICausalLMPipeline pipeline, + string name, + string? systemMessage = "you are a helpful assistant", + IAutoGenChatTemplateBuilder? templateBuilder = null) + { + this.Name = name; + this._pipeline = pipeline; + this._systemMessage = systemMessage; + this._templateBuilder = templateBuilder ?? Mistral_7B_0_3ChatTemplateBuilder.Instance; + } + + public string Name { get; } + + public Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + { + if (_systemMessage != null) + { + var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name); + messages = messages.Prepend(systemMessage); + } + var input = _templateBuilder.BuildPrompt(messages); + var maxLen = options?.MaxToken ?? 1024; + var temperature = options?.Temperature ?? 0.7f; + var stopTokenSequence = options?.StopSequence ?? []; + stopTokenSequence = stopTokenSequence.Append(_stopSequence).ToArray(); + + var output = _pipeline.Generate( + input, + maxLen: maxLen, + temperature: temperature, + stopSequences: stopTokenSequence) ?? throw new InvalidOperationException("Failed to generate a reply."); + + return Task.FromResult(new TextMessage(Role.Assistant, output, from: this.Name)); + } + +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + public async IAsyncEnumerable GenerateStreamingReplyAsync( +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + IEnumerable messages, + GenerateReplyOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (_systemMessage != null) + { + var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name); + messages = messages.Prepend(systemMessage); + } + var input = _templateBuilder.BuildPrompt(messages); + var maxLen = options?.MaxToken ?? 1024; + var temperature = options?.Temperature ?? 0.7f; + var stopTokenSequence = options?.StopSequence ?? []; + stopTokenSequence = stopTokenSequence.Append(_stopSequence).ToArray(); + + foreach (var output in _pipeline.GenerateStreaming( + input, + maxLen: maxLen, + temperature: temperature, + stopSequences: stopTokenSequence)) + { + yield return new TextMessageUpdate(Role.Assistant, output, from: this.Name); + } + } +} diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralConfig.cs b/src/Microsoft.ML.GenAI.Mistral/MistralConfig.cs index a1a66a5585..c2240f9579 100644 --- a/src/Microsoft.ML.GenAI.Mistral/MistralConfig.cs +++ b/src/Microsoft.ML.GenAI.Mistral/MistralConfig.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Text.Json; using System.Text.Json.Serialization; using Microsoft.ML.GenAI.Core; using TorchSharp; @@ -23,7 +24,6 @@ public MistralConfig() this.NumAttentionHeads = 32; this.NumHiddenLayers = 32; this.NumKeyValueHeads = 8; - this.PretrainingTp = 1; this.RmsNormEps = 1e-05f; this.RopeScaling = new RopeScalingConfig(); this.RopeTheta = 500000.0; @@ -31,8 +31,26 @@ public MistralConfig() this.VocabSize = 128256; this.AttnImplementation = "eager"; this.DType = torch.ScalarType.BFloat16; + this.HeadDim = this.HiddenSize / this.NumAttentionHeads; + this.SlidingWindow ??= 4096; } + static MistralConfig() + { +#pragma warning disable MSML_ParameterLocalVarName // Parameter or local variable name not standard + var mistral7BInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Mistral.Resource.Config.mistral-7B-instruct-v0.3.json"); +#pragma warning restore MSML_ParameterLocalVarName // Parameter or local variable name not standard + + Mistral_7B_Instruct_v0_3 = JsonSerializer.Deserialize(mistral7BInstructContent) ?? throw new ArgumentNullException(nameof(mistral7BInstructContent)); + } + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + /// + /// The mistral-7b-instruct-v0.3 configuration created from https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3/tree/main. + /// + public static MistralConfig Mistral_7B_Instruct_v0_3 { get; } +#pragma warning restore MSML_GeneralName // This name should be PascalCased + [JsonPropertyName("attention_bias")] public bool AttentionBias { get; set; } @@ -66,8 +84,8 @@ public MistralConfig() [JsonPropertyName("num_key_value_heads")] public int NumKeyValueHeads { get; set; } - [JsonPropertyName("pretraining_tp")] - public int PretrainingTp { get; set; } + [JsonPropertyName("head_dim")] + public int HeadDim { get; set; } [JsonPropertyName("rms_norm_eps")] public float RmsNormEps { get; set; } @@ -82,7 +100,13 @@ public MistralConfig() [JsonPropertyName("vocab_size")] public int VocabSize { get; set; } + + [JsonPropertyName("sliding_window")] + public int? SlidingWindow { get; set; } + public int? PadTokenId { get; set; } + public torch.ScalarType DType { get; set; } + public string AttnImplementation { get; set; } } diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralDecoderLayer.cs b/src/Microsoft.ML.GenAI.Mistral/MistralDecoderLayer.cs new file mode 100644 index 0000000000..7f17991b5c --- /dev/null +++ b/src/Microsoft.ML.GenAI.Mistral/MistralDecoderLayer.cs @@ -0,0 +1,148 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.GenAI.Core; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Mistral.Module; + +internal class DecoderLayerInput +{ + public DecoderLayerInput( + Tensor hiddenStates, + Tensor attentionMask, + Tensor positionIds, + RotaryEmbeddingOutput positionEmbeddings, // cos, sin + IKVCache? pastKeyValue = null, + bool outputAttentions = false) + { + this.HiddenStates = hiddenStates; + this.AttentionMask = attentionMask; + this.PositionIds = positionIds; + this.PastKeyValue = pastKeyValue; + this.OutputAttentions = outputAttentions; + this.PositionalEmbeddings = positionEmbeddings; + } + + public Tensor HiddenStates { get; set; } + + public Tensor AttentionMask { get; set; } + + public Tensor PositionIds { get; set; } + + public RotaryEmbeddingOutput PositionalEmbeddings { get; set; } + + public IKVCache? PastKeyValue { get; set; } + + public bool OutputAttentions { get; set; } +} + +internal class DecoderLayerOutput +{ + public DecoderLayerOutput( + Tensor hiddenStates, + Tensor? attentions = null, + IKVCache? pastKeyValue = null) + { + this.HiddenStates = hiddenStates; + this.Attentions = attentions; + this.PastKeyValue = pastKeyValue; + } + + public Tensor HiddenStates { get; set; } + + public Tensor? Attentions { get; set; } + + public IKVCache? PastKeyValue { get; set; } +} +internal class MistralDecoderLayer : nn.Module, IDynamicLoadModule +{ + private readonly MistralConfig _llamaConfig; + private readonly int _layerIndex; + private readonly int _hiddenSize; + +#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly MistralMLP mlp; + private readonly Core.RMSNorm input_layernorm; + private readonly Core.RMSNorm post_attention_layernorm; + private readonly Attention self_attn; + + public Action? LoadToDeviceFunc { get; set; } + public Action? UnloadFromDeviceFunc { get; set; } + +#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format + + public MistralDecoderLayer(MistralConfig config, int layerIndex) + : base(nameof(MistralDecoderLayer)) + { + _llamaConfig = config; + _layerIndex = layerIndex; + _hiddenSize = config.HiddenSize; + + this.self_attn = CreateAttention(config, layerIndex); + this.mlp = new MistralMLP(config); + this.input_layernorm = new Core.RMSNorm(this._hiddenSize, eps: config.RmsNormEps, config.DType); + this.post_attention_layernorm = new Core.RMSNorm(this._hiddenSize, eps: config.RmsNormEps, config.DType); + } + + private Attention CreateAttention(MistralConfig config, int layerIndex) + { + var headDim = config.HiddenSize / config.NumAttentionHeads; + return new Attention( + attentionDropout: config.AttentionDropout, + hiddenSize: config.HiddenSize, + numHeads: config.NumAttentionHeads, + headDim: headDim, + numKeyValueHeads: config.NumKeyValueHeads, + numKeyValueGroups: config.NumAttentionHeads / config.NumKeyValueHeads, + maxPositionEmbeddings: config.MaxPositionEmbeddings, + originalMaxPositionEmbeddings: config.MaxPositionEmbeddings, + layerIdx: layerIndex, + useQkvProj: false, + dtype: config.DType, + attentionBias: config.AttentionBias); + } + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override DecoderLayerOutput forward(DecoderLayerInput input) +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + if (LoadToDeviceFunc != null) + { + LoadToDeviceFunc(this); + } + + using var disposeScope = NewDisposeScope(); + var residual = input.HiddenStates; + var hiddenStates = this.input_layernorm.forward(input.HiddenStates); + + var selfAttnInput = new AttentionInput( + hiddenStates: hiddenStates, + attentionMask: input.AttentionMask, + positionIds: input.PositionIds, + cache: input.PastKeyValue, + positionalEmbeddings: input.PositionalEmbeddings, + outputAttentions: input.OutputAttentions); + + var selfAttnOutput = this.self_attn.forward(selfAttnInput); + + hiddenStates = residual + selfAttnOutput.HiddenStates; + + // Fully connected + residual = hiddenStates; + hiddenStates = this.post_attention_layernorm.forward(hiddenStates); + hiddenStates = this.mlp.forward(hiddenStates); + hiddenStates = residual + hiddenStates; + + if (UnloadFromDeviceFunc != null) + { + UnloadFromDeviceFunc(this); + } + + return new DecoderLayerOutput( + hiddenStates: hiddenStates.MoveToOuterDisposeScope(), + attentions: input.OutputAttentions ? selfAttnOutput.Attentions?.MoveToOuterDisposeScope() : null, + pastKeyValue: selfAttnOutput.Cache); + } +} diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs b/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs new file mode 100644 index 0000000000..3d3ac11bbd --- /dev/null +++ b/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs @@ -0,0 +1,121 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics; +using System.Text.Json; +using Microsoft.ML.GenAI.Core; +using Microsoft.ML.GenAI.Core.Extension; +using Microsoft.ML.GenAI.Mistral.Module; +using TorchSharp; +using TorchSharp.PyBridge; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Mistral; + +public class MistralForCausalLM : nn.Module +{ + private readonly MistralConfig _config; + private readonly int _vocabSize; + +#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly GenAILinear lm_head; + private readonly MistralModel model; +#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format + + public MistralForCausalLM(MistralConfig config) + : base(nameof(MistralForCausalLM)) + { + _config = config; + _vocabSize = config.VocabSize; + + model = new MistralModel(config); + lm_head = new GenAILinear(config.HiddenSize, config.VocabSize, hasBias: false); + + this.RegisterComponents(); + } + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override CausalLMModelOutput forward(CausalLMModelInput input) +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + var outputs = this.model.forward(input); + var logits = this.lm_head.forward(outputs.LastHiddenState); + logits = logits.to_type(ScalarType.Float32); + outputs.Logits = logits; + + return outputs; + } + + public static MistralForCausalLM FromPretrained( + string modelFolder, + string configName = "config.json", + string checkPointName = "model.safetensors.index.json", + ScalarType torchDtype = ScalarType.BFloat16, + string device = "cpu") + { + var config = Path.Join(modelFolder, configName); + var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); + modelConfig.DType = torchDtype; + var model = new MistralForCausalLM(modelConfig); + + model.LoadSafeTensors(modelFolder, checkPointName); + model = model.to(device); + + return model; + } + + public static MistralForCausalLM FromPretrained( + string modelFolder, + string configName = "config.json", + string checkPointName = "model.safetensors.index.json", + bool quantizeToInt8 = false, + bool quantizeToInt4 = false, + int layersOnTargetDevice = -1, + ScalarType torchDtype = ScalarType.BFloat16, + string targetDevice = "cuda") + { + if (layersOnTargetDevice == -1 && quantizeToInt4 == false && quantizeToInt8 == false) + { + return FromPretrained(modelFolder, configName, checkPointName, torchDtype, targetDevice); + } + + var originalDefaultDevice = torch.get_default_device(); + torch.set_default_device("meta"); + var config = Path.Join(modelFolder, configName); + var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config)); + modelConfig.DType = torchDtype; + var model = new MistralForCausalLM(modelConfig); + + if (quantizeToInt8) + { + model.ToInt8QuantizeModule(); + } + else if (quantizeToInt4) + { + model.ToInt4QuantizeModule(); + } + + var deviceMap = model.InferDeviceMapForEachLayer( + [ + KeyValuePair.Create(targetDevice, layersOnTargetDevice), + KeyValuePair.Create("cpu", -1) + ]); + + torch.set_default_device("cpu"); + model = new MistralForCausalLM(modelConfig); + + model.LoadSafeTensors(modelFolder, checkPointName); + + model = model.ToDynamicLoadingModel(deviceMap, targetDevice); + + torch.set_default_device(originalDefaultDevice); + + return model; + } + + public void LoadSafeTensors(string modelFolder, string checkPointName = "model.safetensors.index.json") + { + this.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: true, useTqdm: false); + } +} diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralModel.cs b/src/Microsoft.ML.GenAI.Mistral/MistralModel.cs new file mode 100644 index 0000000000..6c05fe53e9 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Mistral/MistralModel.cs @@ -0,0 +1,148 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.GenAI.Core; +using TorchSharp; +using TorchSharp.Modules; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Mistral.Module; + +internal class MistralModel : nn.Module +{ + private readonly MistralConfig _config; + private readonly int? _paddingIdx; + private readonly int _vocabSize; + private IKVCache _cache; +#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly Embedding embed_tokens; + private readonly ModuleList layers; + private readonly RMSNorm norm; +#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format + private readonly nn.Module _rotaryEmb; + + + public MistralModel(MistralConfig config) + : base(nameof(MistralModel)) + { + this._config = config; + this._paddingIdx = config.PadTokenId; + this._vocabSize = config.VocabSize; + var headDim = config.HeadDim; + this.embed_tokens = nn.Embedding(config.VocabSize, config.HiddenSize, padding_idx: this._paddingIdx, dtype: config.DType); + this.layers = new ModuleList(); + + for (int i = 0; i < config.NumHiddenLayers; i++) + { + this.layers.Add(new MistralDecoderLayer(config, i)); + } + this.norm = new RMSNorm(config.HiddenSize, config.RmsNormEps, config.DType); + this._cache = new DynamicKVCache(); + this.RegisterComponents(); + this._rotaryEmb = config.RopeScaling switch + { + null => new RotaryEmbedding(config.RopeTheta, config.MaxPositionEmbeddings, headDim), + _ => new RotaryEmbedding(config.RopeTheta, headDim, config.RopeScaling), + }; + } + +#pragma warning disable MSML_GeneralName // This name should be PascalCased + public override CausalLMModelOutput forward(CausalLMModelInput input) +#pragma warning restore MSML_GeneralName // This name should be PascalCased + { + if (input.OverrideCache is not null) + { + this._cache = input.OverrideCache; + } + + var outputAttentions = input.OutputAttentions; + var outputHiddenStates = input.OutputHiddenStates; + var attentionMask = input.AttentionMask; + Device device; + var inputIds = input.InputIds; + var positionIds = input.PositionIds; + var inputsEmbeds = input.InputEmbeddings; + int batchSize; + int seqLength; + if (inputIds is not null && inputsEmbeds is not null) + { + throw new ArgumentException("Only one of input_ids or inputs_embeds may be set"); + } + else if (inputIds is not null) + { + batchSize = inputIds.IntShape()[0]; + seqLength = inputIds.IntShape()[1]; + inputsEmbeds = this.embed_tokens.forward(inputIds); + device = inputIds.device; + } + else if (inputsEmbeds is not null) + { + batchSize = inputsEmbeds.IntShape()[0]; + seqLength = inputsEmbeds.IntShape()[1]; + device = inputsEmbeds.device; + } + else + { + throw new ArgumentException("Either input_ids or inputs_embeds must be set"); + } + + var pastKeyValuesLength = input.PastKeyValuesLength; + + if (positionIds is null) + { + positionIds = torch.arange(pastKeyValuesLength, seqLength + pastKeyValuesLength, device: device); + positionIds = positionIds.unsqueeze(0).view(-1, seqLength); + } + else + { + positionIds = ((long)positionIds.view(-1, seqLength)); + } + + if (this._config.AttnImplementation == "flash_attention_2") + { + throw new NotImplementedException(); + } + else + { + // the following behavior of creating 4d causal mask doesn't match python's, remember to look into it when there's time. + attentionMask = AttentionMaskConverter.Create4DCausalAttentionMask(attentionMask, [batchSize, seqLength], inputsEmbeds.dtype, device, pastKeyValuesLength, slidingWindow: _config.SlidingWindow); + } + + var hiddenStates = inputsEmbeds; + + var allHiddenStates = new List(); + var allAttentions = new List(); + + var embOutput = this._rotaryEmb.forward(new RotaryEmbeddingInput(hiddenStates, positionIds, pastKeyValuesLength)); + foreach (var layer in this.layers) + { + if (outputHiddenStates) + { + allHiddenStates.Add(hiddenStates); + } + + var decoderInput = new DecoderLayerInput( + hiddenStates: hiddenStates, + attentionMask: attentionMask!, + positionIds: positionIds, + pastKeyValue: this._cache, + positionEmbeddings: embOutput, + outputAttentions: outputAttentions); + var layerOutput = layer.forward(decoderInput); + hiddenStates = layerOutput.HiddenStates; + if (outputAttentions && layerOutput.Attentions is not null) + { + allAttentions.Add(layerOutput.Attentions); + } + } + + hiddenStates = this.norm.forward(hiddenStates); + if (outputHiddenStates) + { + allHiddenStates.Add(hiddenStates); + } + + return new CausalLMModelOutput(lastHiddenState: hiddenStates, allHiddenStates: allHiddenStates.ToArray(), attentions: allAttentions.ToArray(), cache: this._cache); + } +} diff --git a/src/Microsoft.ML.GenAI.Mistral/Mistral_7B_0_3ChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.Mistral/Mistral_7B_0_3ChatTemplateBuilder.cs new file mode 100644 index 0000000000..25df3390be --- /dev/null +++ b/src/Microsoft.ML.GenAI.Mistral/Mistral_7B_0_3ChatTemplateBuilder.cs @@ -0,0 +1,79 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Text; +using AutoGen.Core; +using Microsoft.ML.GenAI.Core; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Microsoft.ML.GenAI.Mistral; + +/// +/// the chat template builder for Mistral 7B v0.3 +/// +#pragma warning disable MSML_GeneralName // This name should be PascalCased +public class Mistral_7B_0_3ChatTemplateBuilder : IChatTemplateBuilder +#pragma warning restore MSML_GeneralName // This name should be PascalCased +{ + private const char Newline = '\n'; + + public static Mistral_7B_0_3ChatTemplateBuilder Instance { get; } = new Mistral_7B_0_3ChatTemplateBuilder(); + + public string BuildPrompt(IEnumerable messages, IEnumerable? tools = null) + { + // can only contain at most one system message + if (messages.Where(m => m.GetRole() == Role.System).Count() > 1) + { + throw new InvalidOperationException("Please provide at most one system message."); + } + + var systemMessage = messages.FirstOrDefault(m => m.GetRole() == Role.System)?.GetContent(); + + // split the messages into two sequences by the last user message + // e.g [user, assistant, user, assistant, user] -> [[user, assistant, user, assistant], [user]] + + var firstSequence = messages.Take(messages.ToList().FindLastIndex(m => m.GetRole() == Role.User) + 1); + var secondSequence = messages.Skip(messages.ToList().FindLastIndex(m => m.GetRole() == Role.User) + 1); + + var sb = new StringBuilder(); + sb.Append(""); + foreach (var message in firstSequence) + { + // skip system + if (message.GetRole() == Role.System) + { + continue; + } + + var role = message.GetRole()!.Value; + var content = message.GetContent()!; + sb.Append(message switch + { + _ when message.GetRole() == Role.User => $"[INST]{content.Trim()}[/INST]", + _ when message.GetRole() == Role.Assistant => $"{content.Trim()}", + _ => throw new InvalidOperationException("Invalid role.") + }); + } + + foreach (var message in secondSequence) + { + var role = message.GetRole()!.Value; + var content = message.GetContent()!; + sb.Append(message switch + { + _ when message.GetRole() == Role.User && !string.IsNullOrEmpty(systemMessage) => $"[INST] {systemMessage} {Newline}{Newline}{content.Trim()}[/INST]", + _ when message.GetRole() == Role.User => $"[INST]{content.Trim()}[/INST]", + _ when message.GetRole() == Role.Assistant => $"{content.Trim()}", + _ => throw new InvalidOperationException("Invalid role.") + }); + } + + return sb.ToString(); + } + + public string BuildPrompt(ChatHistory chatHistory) + { + throw new NotImplementedException(); + } +} diff --git a/src/Microsoft.ML.GenAI.Mistral/Resource/Config/mistral-7B-instruct-v0.3.json b/src/Microsoft.ML.GenAI.Mistral/Resource/Config/mistral-7B-instruct-v0.3.json new file mode 100644 index 0000000000..1da2dde41f --- /dev/null +++ b/src/Microsoft.ML.GenAI.Mistral/Resource/Config/mistral-7B-instruct-v0.3.json @@ -0,0 +1,21 @@ +{ + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 32768, + "model_type": "mistral", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000.0, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "use_cache": true, + "vocab_size": 32768 +} diff --git a/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs b/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs index fdba74ba77..580bde9b12 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs @@ -41,7 +41,7 @@ public Phi2Config() static Phi2Config() { - var phi2ConfigContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-2-config.json"); + var phi2ConfigContent = Core.Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-2-config.json"); var phi2Config = JsonSerializer.Deserialize(phi2ConfigContent) ?? throw new ArgumentNullException(nameof(phi2ConfigContent)); Phi2 = phi2Config; } diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs index def5ab3448..0a020d6724 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs @@ -42,10 +42,10 @@ public Phi3Config() static Phi3Config() { - var phi3Mini4kInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-mini-4k-instruct-config.json"); - var phi3Mini128kInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-mini-128k-instruct-config.json"); - var phi3Medium4kInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-medium-4k-instruct-config.json"); - var phi3Medium128kInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-medium-128k-instruct-config.json"); + var phi3Mini4kInstructContent = Core.Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-mini-4k-instruct-config.json"); + var phi3Mini128kInstructContent = Core.Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-mini-128k-instruct-config.json"); + var phi3Medium4kInstructContent = Core.Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-medium-4k-instruct-config.json"); + var phi3Medium128kInstructContent = Core.Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-medium-128k-instruct-config.json"); Phi3Mini4kInstruct = JsonSerializer.Deserialize(phi3Mini4kInstructContent) ?? throw new ArgumentNullException(nameof(phi3Mini4kInstructContent)); Phi3Mini128kInstruct = JsonSerializer.Deserialize(phi3Mini128kInstructContent) ?? throw new ArgumentNullException(nameof(phi3Mini128kInstructContent)); diff --git a/src/Microsoft.ML.GenAI.Phi/Utils.cs b/src/Microsoft.ML.GenAI.Phi/Utils.cs index aa5a71719e..5be880a06e 100644 --- a/src/Microsoft.ML.GenAI.Phi/Utils.cs +++ b/src/Microsoft.ML.GenAI.Phi/Utils.cs @@ -16,49 +16,6 @@ namespace Microsoft.ML.GenAI.Phi; internal static class Utils { - public static string GetEmbeddedResource(string resourceName) - { - // read file content from embedded resource - var assembly = Assembly.GetExecutingAssembly(); - var resourceStream = assembly.GetManifestResourceStream(resourceName); - - if (resourceStream == null) - { - throw new ArgumentException("Resource not found", nameof(resourceName)); - } - - using var reader = new System.IO.StreamReader(resourceStream); - return reader.ReadToEnd(); - } - - public static Tensor ApplyRotaryEmbeddings(Tensor input, Tensor freqsComplex) - { - // Separate the last dimension pairs of two values, representing the real and imaginary parts of the complex number - // Two consecutive values will become a single complex number - // (B, Seq_Len, H, Head_Dim) -> (B, Seq_Len, H, Head_Dim/2) - var inputComplex = input.to_type(ScalarType.Float32).reshape(input.shape[0], input.shape[1], input.shape[2], -1, 2).view_as_complex(); - freqsComplex = freqsComplex.to(input.device); - - // Reshape the freqs_complex tensor to match the shape of the x_complex tensor. So we need to add the batch dimension and the head dimension - // (Seq_Len, Head_Dim/2) --> (1, Seq_Len, 1, Head_Dim/2) - var freqsComplexReshaped = freqsComplex.unsqueeze(0).unsqueeze(2); - - // Multiply each complex number in the x_complex tensor by the corresponding complex number in the freqs_complex tensor - // Which results in the rotation of the complex number as shown in the Figure 1 of the paper - // (B, Seq_Len, H, Head_Dim/2) * (1, Seq_Len, 1, Head_Dim/2) = (B, Seq_Len, H, Head_Dim/2) - var rotatedComplex = inputComplex * freqsComplexReshaped; - // Console.WriteLine(rotated_complex.mean().ToSingle()); - - // Convert the complex number back to the real number - // (B, Seq_Len, H, Head_Dim/2) -> (B, Seq_Len, H, Head_Dim/2, 2) - var rotated = rotatedComplex.view_as_real(); - - // (B, Seq_Len, H, Head_Dim/2, 2) -> (B, Seq_Len, H, Head_Dim) - var rotatedReshaped = rotated.reshape(rotated.shape[0], rotated.shape[1], rotated.shape[2], -1); - - return rotatedReshaped.type_as(input); - } - public static Tensor PrecomputeThetaPosFrequencies(int headDim, int seqLen, string device, float theta = 10000.0f) { // As written in the paragraph 3.2.2 of the paper @@ -147,21 +104,4 @@ public static Tensor Phi2RepeatKV(Tensor x, int nRep) .expand(batchSize, seqLen, nKVHeads, nRep, headDim) .view(batchSize, seqLen, nKVHeads * nRep, headDim); } - - public static Tensor Phi3RepeatKV(Tensor x, int nRep) - { - var batchSize = x.shape[0]; - var nKVHeads = x.shape[1]; - var seqLen = x.shape[2]; - var headDim = x.shape[3]; - if (nRep == 1) - { - return x; - } - - return x.unsqueeze(3) - .expand(batchSize, nKVHeads, nRep, seqLen, headDim) - .view(batchSize, nKVHeads * nRep, seqLen, headDim); - } - } diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.approved.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.approved.txt new file mode 100644 index 0000000000..51ad32bf9f --- /dev/null +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.approved.txt @@ -0,0 +1 @@ +[INST]Hello?[/INST]World! \ No newline at end of file diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.received.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.received.txt new file mode 100644 index 0000000000..51ad32bf9f --- /dev/null +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.received.txt @@ -0,0 +1 @@ +[INST]Hello?[/INST]World! \ No newline at end of file diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.Mistral_7B_Instruct_V0_3_ShapeTest.approved.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.Mistral_7B_Instruct_V0_3_ShapeTest.approved.txt new file mode 100644 index 0000000000..4bad35f7d7 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.Mistral_7B_Instruct_V0_3_ShapeTest.approved.txt @@ -0,0 +1,291 @@ +0: lm_head.weight shape: [32768, 4096] +1: model.embed_tokens.weight shape: [32768, 4096] +2: model.layers.0.input_layernorm.weight shape: [4096] +3: model.layers.0.mlp.down_proj.weight shape: [4096, 14336] +4: model.layers.0.mlp.gate_proj.weight shape: [14336, 4096] +5: model.layers.0.mlp.up_proj.weight shape: [14336, 4096] +6: model.layers.0.post_attention_layernorm.weight shape: [4096] +7: model.layers.0.self_attn.k_proj.weight shape: [1024, 4096] +8: model.layers.0.self_attn.o_proj.weight shape: [4096, 4096] +9: model.layers.0.self_attn.q_proj.weight shape: [4096, 4096] +10: model.layers.0.self_attn.v_proj.weight shape: [1024, 4096] +11: model.layers.1.input_layernorm.weight shape: [4096] +12: model.layers.1.mlp.down_proj.weight shape: [4096, 14336] +13: model.layers.1.mlp.gate_proj.weight shape: [14336, 4096] +14: model.layers.1.mlp.up_proj.weight shape: [14336, 4096] +15: model.layers.1.post_attention_layernorm.weight shape: [4096] +16: model.layers.1.self_attn.k_proj.weight shape: [1024, 4096] +17: model.layers.1.self_attn.o_proj.weight shape: [4096, 4096] +18: model.layers.1.self_attn.q_proj.weight shape: [4096, 4096] +19: model.layers.1.self_attn.v_proj.weight shape: [1024, 4096] +20: model.layers.10.input_layernorm.weight shape: [4096] +21: model.layers.10.mlp.down_proj.weight shape: [4096, 14336] +22: model.layers.10.mlp.gate_proj.weight shape: [14336, 4096] +23: model.layers.10.mlp.up_proj.weight shape: [14336, 4096] +24: model.layers.10.post_attention_layernorm.weight shape: [4096] +25: model.layers.10.self_attn.k_proj.weight shape: [1024, 4096] +26: model.layers.10.self_attn.o_proj.weight shape: [4096, 4096] +27: model.layers.10.self_attn.q_proj.weight shape: [4096, 4096] +28: model.layers.10.self_attn.v_proj.weight shape: [1024, 4096] +29: model.layers.11.input_layernorm.weight shape: [4096] +30: model.layers.11.mlp.down_proj.weight shape: [4096, 14336] +31: model.layers.11.mlp.gate_proj.weight shape: [14336, 4096] +32: model.layers.11.mlp.up_proj.weight shape: [14336, 4096] +33: model.layers.11.post_attention_layernorm.weight shape: [4096] +34: model.layers.11.self_attn.k_proj.weight shape: [1024, 4096] +35: model.layers.11.self_attn.o_proj.weight shape: [4096, 4096] +36: model.layers.11.self_attn.q_proj.weight shape: [4096, 4096] +37: model.layers.11.self_attn.v_proj.weight shape: [1024, 4096] +38: model.layers.12.input_layernorm.weight shape: [4096] +39: model.layers.12.mlp.down_proj.weight shape: [4096, 14336] +40: model.layers.12.mlp.gate_proj.weight shape: [14336, 4096] +41: model.layers.12.mlp.up_proj.weight shape: [14336, 4096] +42: model.layers.12.post_attention_layernorm.weight shape: [4096] +43: model.layers.12.self_attn.k_proj.weight shape: [1024, 4096] +44: model.layers.12.self_attn.o_proj.weight shape: [4096, 4096] +45: model.layers.12.self_attn.q_proj.weight shape: [4096, 4096] +46: model.layers.12.self_attn.v_proj.weight shape: [1024, 4096] +47: model.layers.13.input_layernorm.weight shape: [4096] +48: model.layers.13.mlp.down_proj.weight shape: [4096, 14336] +49: model.layers.13.mlp.gate_proj.weight shape: [14336, 4096] +50: model.layers.13.mlp.up_proj.weight shape: [14336, 4096] +51: model.layers.13.post_attention_layernorm.weight shape: [4096] +52: model.layers.13.self_attn.k_proj.weight shape: [1024, 4096] +53: model.layers.13.self_attn.o_proj.weight shape: [4096, 4096] +54: model.layers.13.self_attn.q_proj.weight shape: [4096, 4096] +55: model.layers.13.self_attn.v_proj.weight shape: [1024, 4096] +56: model.layers.14.input_layernorm.weight shape: [4096] +57: model.layers.14.mlp.down_proj.weight shape: [4096, 14336] +58: model.layers.14.mlp.gate_proj.weight shape: [14336, 4096] +59: model.layers.14.mlp.up_proj.weight shape: [14336, 4096] +60: model.layers.14.post_attention_layernorm.weight shape: [4096] +61: model.layers.14.self_attn.k_proj.weight shape: [1024, 4096] +62: model.layers.14.self_attn.o_proj.weight shape: [4096, 4096] +63: model.layers.14.self_attn.q_proj.weight shape: [4096, 4096] +64: model.layers.14.self_attn.v_proj.weight shape: [1024, 4096] +65: model.layers.15.input_layernorm.weight shape: [4096] +66: model.layers.15.mlp.down_proj.weight shape: [4096, 14336] +67: model.layers.15.mlp.gate_proj.weight shape: [14336, 4096] +68: model.layers.15.mlp.up_proj.weight shape: [14336, 4096] +69: model.layers.15.post_attention_layernorm.weight shape: [4096] +70: model.layers.15.self_attn.k_proj.weight shape: [1024, 4096] +71: model.layers.15.self_attn.o_proj.weight shape: [4096, 4096] +72: model.layers.15.self_attn.q_proj.weight shape: [4096, 4096] +73: model.layers.15.self_attn.v_proj.weight shape: [1024, 4096] +74: model.layers.16.input_layernorm.weight shape: [4096] +75: model.layers.16.mlp.down_proj.weight shape: [4096, 14336] +76: model.layers.16.mlp.gate_proj.weight shape: [14336, 4096] +77: model.layers.16.mlp.up_proj.weight shape: [14336, 4096] +78: model.layers.16.post_attention_layernorm.weight shape: [4096] +79: model.layers.16.self_attn.k_proj.weight shape: [1024, 4096] +80: model.layers.16.self_attn.o_proj.weight shape: [4096, 4096] +81: model.layers.16.self_attn.q_proj.weight shape: [4096, 4096] +82: model.layers.16.self_attn.v_proj.weight shape: [1024, 4096] +83: model.layers.17.input_layernorm.weight shape: [4096] +84: model.layers.17.mlp.down_proj.weight shape: [4096, 14336] +85: model.layers.17.mlp.gate_proj.weight shape: [14336, 4096] +86: model.layers.17.mlp.up_proj.weight shape: [14336, 4096] +87: model.layers.17.post_attention_layernorm.weight shape: [4096] +88: model.layers.17.self_attn.k_proj.weight shape: [1024, 4096] +89: model.layers.17.self_attn.o_proj.weight shape: [4096, 4096] +90: model.layers.17.self_attn.q_proj.weight shape: [4096, 4096] +91: model.layers.17.self_attn.v_proj.weight shape: [1024, 4096] +92: model.layers.18.input_layernorm.weight shape: [4096] +93: model.layers.18.mlp.down_proj.weight shape: [4096, 14336] +94: model.layers.18.mlp.gate_proj.weight shape: [14336, 4096] +95: model.layers.18.mlp.up_proj.weight shape: [14336, 4096] +96: model.layers.18.post_attention_layernorm.weight shape: [4096] +97: model.layers.18.self_attn.k_proj.weight shape: [1024, 4096] +98: model.layers.18.self_attn.o_proj.weight shape: [4096, 4096] +99: model.layers.18.self_attn.q_proj.weight shape: [4096, 4096] +100: model.layers.18.self_attn.v_proj.weight shape: [1024, 4096] +101: model.layers.19.input_layernorm.weight shape: [4096] +102: model.layers.19.mlp.down_proj.weight shape: [4096, 14336] +103: model.layers.19.mlp.gate_proj.weight shape: [14336, 4096] +104: model.layers.19.mlp.up_proj.weight shape: [14336, 4096] +105: model.layers.19.post_attention_layernorm.weight shape: [4096] +106: model.layers.19.self_attn.k_proj.weight shape: [1024, 4096] +107: model.layers.19.self_attn.o_proj.weight shape: [4096, 4096] +108: model.layers.19.self_attn.q_proj.weight shape: [4096, 4096] +109: model.layers.19.self_attn.v_proj.weight shape: [1024, 4096] +110: model.layers.2.input_layernorm.weight shape: [4096] +111: model.layers.2.mlp.down_proj.weight shape: [4096, 14336] +112: model.layers.2.mlp.gate_proj.weight shape: [14336, 4096] +113: model.layers.2.mlp.up_proj.weight shape: [14336, 4096] +114: model.layers.2.post_attention_layernorm.weight shape: [4096] +115: model.layers.2.self_attn.k_proj.weight shape: [1024, 4096] +116: model.layers.2.self_attn.o_proj.weight shape: [4096, 4096] +117: model.layers.2.self_attn.q_proj.weight shape: [4096, 4096] +118: model.layers.2.self_attn.v_proj.weight shape: [1024, 4096] +119: model.layers.20.input_layernorm.weight shape: [4096] +120: model.layers.20.mlp.down_proj.weight shape: [4096, 14336] +121: model.layers.20.mlp.gate_proj.weight shape: [14336, 4096] +122: model.layers.20.mlp.up_proj.weight shape: [14336, 4096] +123: model.layers.20.post_attention_layernorm.weight shape: [4096] +124: model.layers.20.self_attn.k_proj.weight shape: [1024, 4096] +125: model.layers.20.self_attn.o_proj.weight shape: [4096, 4096] +126: model.layers.20.self_attn.q_proj.weight shape: [4096, 4096] +127: model.layers.20.self_attn.v_proj.weight shape: [1024, 4096] +128: model.layers.21.input_layernorm.weight shape: [4096] +129: model.layers.21.mlp.down_proj.weight shape: [4096, 14336] +130: model.layers.21.mlp.gate_proj.weight shape: [14336, 4096] +131: model.layers.21.mlp.up_proj.weight shape: [14336, 4096] +132: model.layers.21.post_attention_layernorm.weight shape: [4096] +133: model.layers.21.self_attn.k_proj.weight shape: [1024, 4096] +134: model.layers.21.self_attn.o_proj.weight shape: [4096, 4096] +135: model.layers.21.self_attn.q_proj.weight shape: [4096, 4096] +136: model.layers.21.self_attn.v_proj.weight shape: [1024, 4096] +137: model.layers.22.input_layernorm.weight shape: [4096] +138: model.layers.22.mlp.down_proj.weight shape: [4096, 14336] +139: model.layers.22.mlp.gate_proj.weight shape: [14336, 4096] +140: model.layers.22.mlp.up_proj.weight shape: [14336, 4096] +141: model.layers.22.post_attention_layernorm.weight shape: [4096] +142: model.layers.22.self_attn.k_proj.weight shape: [1024, 4096] +143: model.layers.22.self_attn.o_proj.weight shape: [4096, 4096] +144: model.layers.22.self_attn.q_proj.weight shape: [4096, 4096] +145: model.layers.22.self_attn.v_proj.weight shape: [1024, 4096] +146: model.layers.23.input_layernorm.weight shape: [4096] +147: model.layers.23.mlp.down_proj.weight shape: [4096, 14336] +148: model.layers.23.mlp.gate_proj.weight shape: [14336, 4096] +149: model.layers.23.mlp.up_proj.weight shape: [14336, 4096] +150: model.layers.23.post_attention_layernorm.weight shape: [4096] +151: model.layers.23.self_attn.k_proj.weight shape: [1024, 4096] +152: model.layers.23.self_attn.o_proj.weight shape: [4096, 4096] +153: model.layers.23.self_attn.q_proj.weight shape: [4096, 4096] +154: model.layers.23.self_attn.v_proj.weight shape: [1024, 4096] +155: model.layers.24.input_layernorm.weight shape: [4096] +156: model.layers.24.mlp.down_proj.weight shape: [4096, 14336] +157: model.layers.24.mlp.gate_proj.weight shape: [14336, 4096] +158: model.layers.24.mlp.up_proj.weight shape: [14336, 4096] +159: model.layers.24.post_attention_layernorm.weight shape: [4096] +160: model.layers.24.self_attn.k_proj.weight shape: [1024, 4096] +161: model.layers.24.self_attn.o_proj.weight shape: [4096, 4096] +162: model.layers.24.self_attn.q_proj.weight shape: [4096, 4096] +163: model.layers.24.self_attn.v_proj.weight shape: [1024, 4096] +164: model.layers.25.input_layernorm.weight shape: [4096] +165: model.layers.25.mlp.down_proj.weight shape: [4096, 14336] +166: model.layers.25.mlp.gate_proj.weight shape: [14336, 4096] +167: model.layers.25.mlp.up_proj.weight shape: [14336, 4096] +168: model.layers.25.post_attention_layernorm.weight shape: [4096] +169: model.layers.25.self_attn.k_proj.weight shape: [1024, 4096] +170: model.layers.25.self_attn.o_proj.weight shape: [4096, 4096] +171: model.layers.25.self_attn.q_proj.weight shape: [4096, 4096] +172: model.layers.25.self_attn.v_proj.weight shape: [1024, 4096] +173: model.layers.26.input_layernorm.weight shape: [4096] +174: model.layers.26.mlp.down_proj.weight shape: [4096, 14336] +175: model.layers.26.mlp.gate_proj.weight shape: [14336, 4096] +176: model.layers.26.mlp.up_proj.weight shape: [14336, 4096] +177: model.layers.26.post_attention_layernorm.weight shape: [4096] +178: model.layers.26.self_attn.k_proj.weight shape: [1024, 4096] +179: model.layers.26.self_attn.o_proj.weight shape: [4096, 4096] +180: model.layers.26.self_attn.q_proj.weight shape: [4096, 4096] +181: model.layers.26.self_attn.v_proj.weight shape: [1024, 4096] +182: model.layers.27.input_layernorm.weight shape: [4096] +183: model.layers.27.mlp.down_proj.weight shape: [4096, 14336] +184: model.layers.27.mlp.gate_proj.weight shape: [14336, 4096] +185: model.layers.27.mlp.up_proj.weight shape: [14336, 4096] +186: model.layers.27.post_attention_layernorm.weight shape: [4096] +187: model.layers.27.self_attn.k_proj.weight shape: [1024, 4096] +188: model.layers.27.self_attn.o_proj.weight shape: [4096, 4096] +189: model.layers.27.self_attn.q_proj.weight shape: [4096, 4096] +190: model.layers.27.self_attn.v_proj.weight shape: [1024, 4096] +191: model.layers.28.input_layernorm.weight shape: [4096] +192: model.layers.28.mlp.down_proj.weight shape: [4096, 14336] +193: model.layers.28.mlp.gate_proj.weight shape: [14336, 4096] +194: model.layers.28.mlp.up_proj.weight shape: [14336, 4096] +195: model.layers.28.post_attention_layernorm.weight shape: [4096] +196: model.layers.28.self_attn.k_proj.weight shape: [1024, 4096] +197: model.layers.28.self_attn.o_proj.weight shape: [4096, 4096] +198: model.layers.28.self_attn.q_proj.weight shape: [4096, 4096] +199: model.layers.28.self_attn.v_proj.weight shape: [1024, 4096] +200: model.layers.29.input_layernorm.weight shape: [4096] +201: model.layers.29.mlp.down_proj.weight shape: [4096, 14336] +202: model.layers.29.mlp.gate_proj.weight shape: [14336, 4096] +203: model.layers.29.mlp.up_proj.weight shape: [14336, 4096] +204: model.layers.29.post_attention_layernorm.weight shape: [4096] +205: model.layers.29.self_attn.k_proj.weight shape: [1024, 4096] +206: model.layers.29.self_attn.o_proj.weight shape: [4096, 4096] +207: model.layers.29.self_attn.q_proj.weight shape: [4096, 4096] +208: model.layers.29.self_attn.v_proj.weight shape: [1024, 4096] +209: model.layers.3.input_layernorm.weight shape: [4096] +210: model.layers.3.mlp.down_proj.weight shape: [4096, 14336] +211: model.layers.3.mlp.gate_proj.weight shape: [14336, 4096] +212: model.layers.3.mlp.up_proj.weight shape: [14336, 4096] +213: model.layers.3.post_attention_layernorm.weight shape: [4096] +214: model.layers.3.self_attn.k_proj.weight shape: [1024, 4096] +215: model.layers.3.self_attn.o_proj.weight shape: [4096, 4096] +216: model.layers.3.self_attn.q_proj.weight shape: [4096, 4096] +217: model.layers.3.self_attn.v_proj.weight shape: [1024, 4096] +218: model.layers.30.input_layernorm.weight shape: [4096] +219: model.layers.30.mlp.down_proj.weight shape: [4096, 14336] +220: model.layers.30.mlp.gate_proj.weight shape: [14336, 4096] +221: model.layers.30.mlp.up_proj.weight shape: [14336, 4096] +222: model.layers.30.post_attention_layernorm.weight shape: [4096] +223: model.layers.30.self_attn.k_proj.weight shape: [1024, 4096] +224: model.layers.30.self_attn.o_proj.weight shape: [4096, 4096] +225: model.layers.30.self_attn.q_proj.weight shape: [4096, 4096] +226: model.layers.30.self_attn.v_proj.weight shape: [1024, 4096] +227: model.layers.31.input_layernorm.weight shape: [4096] +228: model.layers.31.mlp.down_proj.weight shape: [4096, 14336] +229: model.layers.31.mlp.gate_proj.weight shape: [14336, 4096] +230: model.layers.31.mlp.up_proj.weight shape: [14336, 4096] +231: model.layers.31.post_attention_layernorm.weight shape: [4096] +232: model.layers.31.self_attn.k_proj.weight shape: [1024, 4096] +233: model.layers.31.self_attn.o_proj.weight shape: [4096, 4096] +234: model.layers.31.self_attn.q_proj.weight shape: [4096, 4096] +235: model.layers.31.self_attn.v_proj.weight shape: [1024, 4096] +236: model.layers.4.input_layernorm.weight shape: [4096] +237: model.layers.4.mlp.down_proj.weight shape: [4096, 14336] +238: model.layers.4.mlp.gate_proj.weight shape: [14336, 4096] +239: model.layers.4.mlp.up_proj.weight shape: [14336, 4096] +240: model.layers.4.post_attention_layernorm.weight shape: [4096] +241: model.layers.4.self_attn.k_proj.weight shape: [1024, 4096] +242: model.layers.4.self_attn.o_proj.weight shape: [4096, 4096] +243: model.layers.4.self_attn.q_proj.weight shape: [4096, 4096] +244: model.layers.4.self_attn.v_proj.weight shape: [1024, 4096] +245: model.layers.5.input_layernorm.weight shape: [4096] +246: model.layers.5.mlp.down_proj.weight shape: [4096, 14336] +247: model.layers.5.mlp.gate_proj.weight shape: [14336, 4096] +248: model.layers.5.mlp.up_proj.weight shape: [14336, 4096] +249: model.layers.5.post_attention_layernorm.weight shape: [4096] +250: model.layers.5.self_attn.k_proj.weight shape: [1024, 4096] +251: model.layers.5.self_attn.o_proj.weight shape: [4096, 4096] +252: model.layers.5.self_attn.q_proj.weight shape: [4096, 4096] +253: model.layers.5.self_attn.v_proj.weight shape: [1024, 4096] +254: model.layers.6.input_layernorm.weight shape: [4096] +255: model.layers.6.mlp.down_proj.weight shape: [4096, 14336] +256: model.layers.6.mlp.gate_proj.weight shape: [14336, 4096] +257: model.layers.6.mlp.up_proj.weight shape: [14336, 4096] +258: model.layers.6.post_attention_layernorm.weight shape: [4096] +259: model.layers.6.self_attn.k_proj.weight shape: [1024, 4096] +260: model.layers.6.self_attn.o_proj.weight shape: [4096, 4096] +261: model.layers.6.self_attn.q_proj.weight shape: [4096, 4096] +262: model.layers.6.self_attn.v_proj.weight shape: [1024, 4096] +263: model.layers.7.input_layernorm.weight shape: [4096] +264: model.layers.7.mlp.down_proj.weight shape: [4096, 14336] +265: model.layers.7.mlp.gate_proj.weight shape: [14336, 4096] +266: model.layers.7.mlp.up_proj.weight shape: [14336, 4096] +267: model.layers.7.post_attention_layernorm.weight shape: [4096] +268: model.layers.7.self_attn.k_proj.weight shape: [1024, 4096] +269: model.layers.7.self_attn.o_proj.weight shape: [4096, 4096] +270: model.layers.7.self_attn.q_proj.weight shape: [4096, 4096] +271: model.layers.7.self_attn.v_proj.weight shape: [1024, 4096] +272: model.layers.8.input_layernorm.weight shape: [4096] +273: model.layers.8.mlp.down_proj.weight shape: [4096, 14336] +274: model.layers.8.mlp.gate_proj.weight shape: [14336, 4096] +275: model.layers.8.mlp.up_proj.weight shape: [14336, 4096] +276: model.layers.8.post_attention_layernorm.weight shape: [4096] +277: model.layers.8.self_attn.k_proj.weight shape: [1024, 4096] +278: model.layers.8.self_attn.o_proj.weight shape: [4096, 4096] +279: model.layers.8.self_attn.q_proj.weight shape: [4096, 4096] +280: model.layers.8.self_attn.v_proj.weight shape: [1024, 4096] +281: model.layers.9.input_layernorm.weight shape: [4096] +282: model.layers.9.mlp.down_proj.weight shape: [4096, 14336] +283: model.layers.9.mlp.gate_proj.weight shape: [14336, 4096] +284: model.layers.9.mlp.up_proj.weight shape: [14336, 4096] +285: model.layers.9.post_attention_layernorm.weight shape: [4096] +286: model.layers.9.self_attn.k_proj.weight shape: [1024, 4096] +287: model.layers.9.self_attn.o_proj.weight shape: [4096, 4096] +288: model.layers.9.self_attn.q_proj.weight shape: [4096, 4096] +289: model.layers.9.self_attn.v_proj.weight shape: [1024, 4096] +290: model.norm.weight shape: [4096] diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.TokenizerTest.approved.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.TokenizerTest.approved.txt new file mode 100644 index 0000000000..fc8562c9e9 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.TokenizerTest.approved.txt @@ -0,0 +1,2 @@ + [{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "format": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location."}}, "required": ["location", "format"]}}}] What's the weather like in Paris? [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}, "id": "9Ae3bDc2F"}] {"content": 22.0, "call_id": "9Ae3bDc2F"} The current temperature in Paris is 22.0 degrees celsius. +1, 1, 6, 1501, 7567, 1891, 2032, 1113, 3396, 1316, 1113, 3396, 2032, 10598, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 1537, 1991, 1316, 1113, 7286, 2032, 1113, 2226, 1040, 2636, 8854, 1316, 1113, 12206, 2032, 10598, 1891, 2032, 1113, 3582, 1316, 1113, 11491, 2032, 10598, 3501, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 7286, 2032, 1113, 1782, 3758, 1072, 2433, 29493, 1085, 29491, 29489, 29491, 4420, 10454, 29493, 10229, 8474, 1113, 4530, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 10825, 2032, 8135, 29485, 1958, 3938, 1316, 1113, 29490, 19425, 13075, 9651, 1113, 7286, 2032, 1113, 1782, 8409, 5796, 1066, 1706, 29491, 1328, 1410, 1224, 1245, 1040, 6211, 5491, 1379, 11549, 1113, 11661, 2032, 8135, 3501, 1316, 1113, 4530, 3010, 14879, 29561, 7, 3, 2592, 29510, 29481, 1040, 8854, 1505, 1065, 6233, 29572, 4, 5, 1501, 7567, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 29475, 17329, 1316, 1113, 17452, 2032, 10598, 3501, 2032, 1113, 4684, 1046, 29493, 5611, 1316, 1113, 6074, 2032, 1113, 29485, 1958, 3938, 8474, 1113, 1081, 2032, 1113, 29542, 29509, 29474, 29538, 29494, 29525, 29485, 29518, 29533, 29507, 10925, 2, 8, 10598, 4557, 2032, 29473, 29518, 29518, 29491, 29502, 29493, 1113, 3613, 29498, 1081, 2032, 1113, 29542, 29509, 29474, 29538, 29494, 29525, 29485, 29518, 29533, 18163, 9, 1183, 2636, 8409, 1065, 6233, 1117, 29473, 29518, 29518, 29491, 29502, 11950, 1045, 1958, 3938, 29491, 2 diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_V0_3Tests.cs b/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs similarity index 64% rename from test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_V0_3Tests.cs rename to test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs index 63017258bf..bf85a911e4 100644 --- a/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_V0_3Tests.cs +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs @@ -7,11 +7,52 @@ using ApprovalTests.Reporters; using ApprovalTests; using Xunit; +using TorchSharp; +using Microsoft.ML.GenAI.Core.Extension; +using AutoGen.Core; namespace Microsoft.ML.GenAI.Mistral.Tests; -public class Mistral_V0_3Tests +[Collection("NoParallelization")] +public class Mistral_7B_Instruct_V0_3Tests { + public Mistral_7B_Instruct_V0_3Tests() + { + if (Environment.GetEnvironmentVariable("HELIX_CORRELATION_ID") != null) + { + Approvals.UseAssemblyLocationForApprovedFiles(); + } + + torch.set_default_device("meta"); + } + + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("Approvals")] + public void Mistral_7B_Instruct_V0_3_ShapeTest() + { + var model = new MistralForCausalLM(MistralConfig.Mistral_7B_Instruct_v0_3); + var stateDictStr = model.PeekShape(); + Approvals.Verify(stateDictStr); + } + + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("Approvals")] + public void ItBuildChatTemplateFromAutoGenChatHistory() + { + var chatHistory = new List + { + new TextMessage(Role.System, "You are a helpful AI assistant."), + new TextMessage(Role.User, "Hello?"), + new TextMessage(Role.Assistant, "World!"), + }; + + var prompt = Mistral_7B_0_3ChatTemplateBuilder.Instance.BuildPrompt(chatHistory); + + Approvals.Verify(prompt); + } + [Fact] [UseReporter(typeof(DiffReporter))] [UseApprovalSubdirectory("Approvals")] From 51f6d6ca39117cad4bf9f348339f38eae625ec1b Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 3 Sep 2024 17:36:54 -0700 Subject: [PATCH 03/17] add tool call support --- .../Microsoft.ML.GenAI.Samples.csproj | 3 + .../Mistral/Mistral_7B_Instruct.cs | 59 ++++++- .../Microsoft.ML.GenAI.Samples/Program.cs | 2 +- .../Pipeline/CausalLMPipeline.cs | 7 + .../MistralCausalLMAgent.cs | 83 +++++++++- .../MistralTokenizerHelper.cs | 1 - .../Mistral_7B_0_3ChatTemplateBuilder.cs | 145 ++++++++++++++++-- ...emplateFromAutoGenChatHistory.approved.txt | 4 +- ...emplateFromAutoGenChatHistory.received.txt | 4 +- ...thToolsFromAutoGenChatHistory.approved.txt | 3 + ...thToolsFromAutoGenChatHistory.received.txt | 3 + .../Mistral_7B_Instruct_V0_3Tests.cs | 48 ++++++ 12 files changed, 341 insertions(+), 21 deletions(-) create mode 100644 test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateWithToolsFromAutoGenChatHistory.approved.txt create mode 100644 test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateWithToolsFromAutoGenChatHistory.received.txt diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj index fab7d39b57..596d149e38 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj @@ -5,6 +5,7 @@ net8.0 enable enable + true @@ -17,6 +18,8 @@ + + diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs index e91c608279..4afc3806b6 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs @@ -12,9 +12,21 @@ namespace Microsoft.ML.GenAI.Samples.Mistral; -internal class Mistral_7B_Instruct +public partial class Mistral_7B_Instruct { - public static async void Run() + private static Mistral_7B_Instruct instance = new Mistral_7B_Instruct(); + + /// + /// get weather from city + /// + /// + [Function] + public async Task GetWeather(string city) + { + return await Task.FromResult($"The weather in {city} is sunny."); + } + + public static async Task RunAsync() { var device = "cuda"; if (device == "cuda") @@ -39,9 +51,50 @@ public static async void Run() .RegisterPrintMessage(); var task = """ - Write a C# program to print the sum of two numbers. Use top-level statement, put code between ```csharp and ```. + How are you. """; await agent.SendAsync(task); } + + public async static Task WeatherChatAsync() + { + var device = "cuda"; + if (device == "cuda") + { + torch.InitializeDeviceType(DeviceType.CUDA); + } + + var defaultType = ScalarType.BFloat16; + torch.manual_seed(1); + torch.set_default_dtype(defaultType); + var weightFolder = @"C:\Users\xiaoyuz\source\repos\Mistral-7B-Instruct-v0.3"; + var configName = "config.json"; + var originalWeightFolder = Path.Combine(weightFolder); + + Console.WriteLine("Loading Mistral from huggingface model weight folder"); + var tokenizer = MistralTokenizerHelper.FromPretrained(originalWeightFolder); + var model = MistralForCausalLM.FromPretrained(weightFolder, configName, layersOnTargetDevice: -1); + + var pipeline = new CausalLMPipeline(tokenizer, model, device); + + var weatherChatMiddleware = new FunctionCallMiddleware( + functions: [instance.GetWeatherFunctionContract], + functionMap: new Dictionary>> + { + { instance.GetWeatherFunctionContract.Name!, instance.GetWeatherWrapper } + }); + + var agent = new MistralCausalLMAgent(pipeline, "assistant") + .RegisterStreamingMiddleware(weatherChatMiddleware) + .RegisterPrintMessage(); + + var task = "what is the weather in Seattle"; + var userMessage = new TextMessage(Role.User, task); + + var reply = await agent.SendAsync(userMessage); + + // generate further reply using tool call result; + await agent.SendAsync(chatHistory: [userMessage, reply]); + } } diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs index 2ebfb25d38..cf166c7552 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs @@ -2,4 +2,4 @@ using Microsoft.ML.GenAI.Samples.Mistral; using Microsoft.ML.GenAI.Samples.Phi3Mini; -Mistral_7B_Instruct.Run(); +await Mistral_7B_Instruct.WeatherChatAsync(); diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs index 33e0bab19c..72da7c21d7 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs @@ -266,6 +266,13 @@ public virtual IEnumerable GenerateStreaming( foreach (var (token, _) in this.GenerateStreaming(inputTensor, attentionMask, stopTokenIds.ToArray(), temperature: temperature, maxLen: maxLen)) { var tokenIds = token[0].to_type(ScalarType.Int32).data().ToArray(); + if (this.Tokenizer is LlamaTokenizer llamaTokenizer && llamaTokenizer.SpecialTokens?.FirstOrDefault(kv => kv.Value == tokenIds[0]) is { Key: string specialToken }) + { + Console.WriteLine($"token: {tokenIds[0]}"); + yield return specialToken; + continue; + } + var duplicateTokenString = this.Tokenizer.Decode(tokenIds.Concat(tokenIds)) ?? throw new InvalidOperationException("Failed to decode token ids"); var tokenString = this.Tokenizer.Decode(tokenIds) ?? throw new InvalidOperationException("Failed to decode token ids"); // replace the first occurrence of the token with the duplicate token diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralCausalLMAgent.cs b/src/Microsoft.ML.GenAI.Mistral/MistralCausalLMAgent.cs index 1f25d31b09..f5677a73ef 100644 --- a/src/Microsoft.ML.GenAI.Mistral/MistralCausalLMAgent.cs +++ b/src/Microsoft.ML.GenAI.Mistral/MistralCausalLMAgent.cs @@ -3,6 +3,10 @@ // See the LICENSE file in the project root for more information. using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; using AutoGen.Core; using Microsoft.ML.GenAI.Core; using Microsoft.ML.Tokenizers; @@ -44,7 +48,7 @@ public Task GenerateReplyAsync(IEnumerable messages, Generat var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name); messages = messages.Prepend(systemMessage); } - var input = _templateBuilder.BuildPrompt(messages); + var input = _templateBuilder.BuildPrompt(messages, options?.Functions); var maxLen = options?.MaxToken ?? 1024; var temperature = options?.Temperature ?? 0.7f; var stopTokenSequence = options?.StopSequence ?? []; @@ -56,6 +60,12 @@ public Task GenerateReplyAsync(IEnumerable messages, Generat temperature: temperature, stopSequences: stopTokenSequence) ?? throw new InvalidOperationException("Failed to generate a reply."); + // post-process the output for tool call + if (output.StartsWith("[TOOL_CALLS]")) + { + return Task.FromResult(ParseAsToolCallMessage(output)); + } + return Task.FromResult(new TextMessage(Role.Assistant, output, from: this.Name)); } @@ -71,19 +81,86 @@ public async IAsyncEnumerable GenerateStreamingReplyAsync( var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name); messages = messages.Prepend(systemMessage); } - var input = _templateBuilder.BuildPrompt(messages); + var input = _templateBuilder.BuildPrompt(messages, options?.Functions); var maxLen = options?.MaxToken ?? 1024; var temperature = options?.Temperature ?? 0.7f; var stopTokenSequence = options?.StopSequence ?? []; stopTokenSequence = stopTokenSequence.Append(_stopSequence).ToArray(); + // only streaming the output when the output is not a tool call + // otherwise, we collect all the chunks and convert them to a tool call message at the end of the streaming + var sb = new StringBuilder(); + bool? isToolCall = null; foreach (var output in _pipeline.GenerateStreaming( input, maxLen: maxLen, temperature: temperature, stopSequences: stopTokenSequence)) { - yield return new TextMessageUpdate(Role.Assistant, output, from: this.Name); + if (isToolCall is null) + { + sb.Append(output); + var str = sb.ToString(); + if (!str.StartsWith("[TOOL_CALLS]".Substring(0, str.Length))) + { + yield return new TextMessageUpdate(Role.Assistant, output, from: this.Name); + isToolCall = false; + } + else if (str.StartsWith("[TOOL_CALLS]")) + { + isToolCall = true; + } + } + else if (isToolCall == false) + { + yield return new TextMessageUpdate(Role.Assistant, output, from: this.Name); + } + else + { + sb.Append(output); + } + } + + if (isToolCall == true) + { + var toolCallMessage = ParseAsToolCallMessage(sb.ToString()); + foreach (var toolCall in toolCallMessage.ToolCalls) + { + yield return new ToolCallMessageUpdate(toolCall.FunctionName, toolCall.FunctionArguments, from: this.Name); + } } } + + private class MistralToolCall + { + [JsonPropertyName("name")] + public string? Name { get; set; } + + [JsonPropertyName("arguments")] + public JsonObject? Arguments { get; set; } + } + + private ToolCallMessage ParseAsToolCallMessage(string content) + { + var json = content.Substring("[TOOL_CALLS]".Length).Trim(); + + // the json string should be a list of tool call messages + // e.g. [{"name": "get_current_weather", "parameters": {"location": "Seattle"}}] + var mistralToolCalls = JsonSerializer.Deserialize>(json) ?? throw new InvalidOperationException("Failed to deserialize tool calls."); + var toolCalls = mistralToolCalls + .Select(tc => new ToolCall(tc.Name!, JsonSerializer.Serialize(tc.Arguments)) { ToolCallId = this.GenerateToolCallId() }); + + return new ToolCallMessage(toolCalls, from: this.Name); + } + + /// + /// 9 random alphanumeric characters + /// + private string GenerateToolCallId(int length = 9) + { + const string chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + var random = new Random(); + return new string(Enumerable.Repeat(chars, length) + .Select(s => s[random.Next(s.Length)]).ToArray()); + } } diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralTokenizerHelper.cs b/src/Microsoft.ML.GenAI.Mistral/MistralTokenizerHelper.cs index 4da9091030..d1d8f46e22 100644 --- a/src/Microsoft.ML.GenAI.Mistral/MistralTokenizerHelper.cs +++ b/src/Microsoft.ML.GenAI.Mistral/MistralTokenizerHelper.cs @@ -61,7 +61,6 @@ public static LlamaTokenizer FromPretrained( { var specialTokens = new Dictionary { - { unknownSymbol, unknownSymbolId }, { startSymbol, startSymbolId }, { endSymbol, endSymbolId }, { startInstructionSymbol, startInstructionSymbolId }, diff --git a/src/Microsoft.ML.GenAI.Mistral/Mistral_7B_0_3ChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.Mistral/Mistral_7B_0_3ChatTemplateBuilder.cs index 25df3390be..8852f62da9 100644 --- a/src/Microsoft.ML.GenAI.Mistral/Mistral_7B_0_3ChatTemplateBuilder.cs +++ b/src/Microsoft.ML.GenAI.Mistral/Mistral_7B_0_3ChatTemplateBuilder.cs @@ -3,7 +3,11 @@ // See the LICENSE file in the project root for more information. using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; using AutoGen.Core; +using Json.Schema; +using Json.Schema.Generation; using Microsoft.ML.GenAI.Core; using Microsoft.SemanticKernel.ChatCompletion; @@ -16,7 +20,7 @@ namespace Microsoft.ML.GenAI.Mistral; public class Mistral_7B_0_3ChatTemplateBuilder : IChatTemplateBuilder #pragma warning restore MSML_GeneralName // This name should be PascalCased { - private const char Newline = '\n'; + private const string Newline = "\r\n"; public static Mistral_7B_0_3ChatTemplateBuilder Instance { get; } = new Mistral_7B_0_3ChatTemplateBuilder(); @@ -33,11 +37,10 @@ public string BuildPrompt(IEnumerable messages, IEnumerable [[user, assistant, user, assistant], [user]] - var firstSequence = messages.Take(messages.ToList().FindLastIndex(m => m.GetRole() == Role.User) + 1); - var secondSequence = messages.Skip(messages.ToList().FindLastIndex(m => m.GetRole() == Role.User) + 1); + var firstSequence = messages.Take(messages.ToList().FindLastIndex(m => m.GetRole() == Role.User)); + var secondSequence = messages.Skip(messages.ToList().FindLastIndex(m => m.GetRole() == Role.User)); var sb = new StringBuilder(); - sb.Append(""); foreach (var message in firstSequence) { // skip system @@ -46,25 +49,54 @@ public string BuildPrompt(IEnumerable messages, IEnumerable $"[INST]{content.Trim()}[/INST]", - _ when message.GetRole() == Role.Assistant => $"{content.Trim()}", + ToolCallMessage toolCallMessage => BuildFromToolCallMessage(toolCallMessage), + ToolCallResultMessage toolCallResultMessage => BuildFromToolCallResultMessage(toolCallResultMessage), + ToolCallAggregateMessage toolCallAggregateMessage => BuildFromAggregrateToolCallMessage(toolCallAggregateMessage), + TextMessage when message.GetRole() == Role.User => $"[INST]{content.Trim()}[/INST]", + TextMessage when message.GetRole() == Role.Assistant => $"{content.Trim()}", _ => throw new InvalidOperationException("Invalid role.") }); } + // insert [AVAILABLE TOOLS] section if tools are provided + if (tools?.Any() == true) + { + var schemas = tools.Select(t => new + { + type = "function", + function = new + { + name = t.Name, + description = t.Description, + parameters = BuildJsonSchemaFromFunctionContract(t) + } + }); + var schemaPrompt = JsonSerializer.Serialize(schemas); + + // add a space after the colon in json string so mistral can correctly generate the stop token after [TOOL_CALLS] symbol. + // This is probably because in the training data, all the tool call samples are separated by a space after the colon. + // e.g. [AVAILABLE_TOOLS][{"type": "function", "function": {.... + // instead of [AVAILABLE_TOOLS][{"type":"function","function":{.... + // Therefore when inferencing, we need to add a space after the colon in the json string to match with the training data. + schemaPrompt = schemaPrompt.Replace(":", ": "); + schemaPrompt = schemaPrompt.Replace(",", ", "); + sb.Append($"[AVAILABLE_TOOLS]{schemaPrompt}[/AVAILABLE_TOOLS]"); + } + foreach (var message in secondSequence) { - var role = message.GetRole()!.Value; var content = message.GetContent()!; sb.Append(message switch { - _ when message.GetRole() == Role.User && !string.IsNullOrEmpty(systemMessage) => $"[INST] {systemMessage} {Newline}{Newline}{content.Trim()}[/INST]", - _ when message.GetRole() == Role.User => $"[INST]{content.Trim()}[/INST]", - _ when message.GetRole() == Role.Assistant => $"{content.Trim()}", + ToolCallMessage toolCallMessage => BuildFromToolCallMessage(toolCallMessage), + ToolCallResultMessage toolCallResultMessage => BuildFromToolCallResultMessage(toolCallResultMessage), + ToolCallAggregateMessage toolCallAggregateMessage => BuildFromAggregrateToolCallMessage(toolCallAggregateMessage), + TextMessage when message.GetRole() == Role.User && !string.IsNullOrEmpty(systemMessage) => $"[INST]{systemMessage}{Newline}{Newline}{content.Trim()}[/INST]", + TextMessage when message.GetRole() == Role.User => $"[INST]{content.Trim()}[/INST]", + TextMessage when message.GetRole() == Role.Assistant => $"{content.Trim()}", _ => throw new InvalidOperationException("Invalid role.") }); } @@ -76,4 +108,95 @@ public string BuildPrompt(ChatHistory chatHistory) { throw new NotImplementedException(); } + + private string BuildFromToolCallMessage(ToolCallMessage message) + { + var toolCalls = message.ToolCalls; + if (toolCalls.Count() == 0) + { + return string.Empty; + } + else + { + var toolCallObjects = toolCalls.Select(tc => + new + { + name = tc.FunctionName, + arguments = JsonObject.Parse(tc.FunctionArguments), + id = tc.ToolCallId, + } + ); + + var toolCallJson = JsonSerializer.Serialize(toolCallObjects); + return $"[TOOL_CALLS]{toolCallJson}"; + } + } + + private string BuildFromToolCallResultMessage(ToolCallResultMessage message) + { + var toolCallResults = message.ToolCalls; + if (toolCallResults.Count() == 0) + { + return string.Empty; + } + else + { + var toolCallResultObjects = toolCallResults.Select(tc => + new + { + id = tc.ToolCallId, + content = tc.Result, + } + ); + + var toolCallResultJson = JsonSerializer.Serialize(toolCallResultObjects); + return $"[TOOL_RESULTS]{toolCallResultJson}[/TOOL_RESULTS]"; + } + } + + private string BuildFromAggregrateToolCallMessage(ToolCallAggregateMessage message) + { + var toolCallMessage = message.Message1; + var toolCallResultMessage = message.Message2; + + var toolCall = BuildFromToolCallMessage(toolCallMessage); + var toolCallResult = BuildFromToolCallResultMessage(toolCallResultMessage); + + return $"{toolCall}{toolCallResult}"; + } + + private JsonSchema BuildJsonSchemaFromFunctionContract(FunctionContract contract) + { + var requiredParameterNames = new List(); + var propertiesSchemas = new Dictionary(); + var propertySchemaBuilder = new JsonSchemaBuilder().Type(SchemaValueType.Object); + foreach (var param in contract.Parameters ?? []) + { + if (param.Name is null) + { + throw new InvalidOperationException("Parameter name cannot be null"); + } + + var schemaBuilder = new JsonSchemaBuilder().FromType(param.ParameterType ?? throw new ArgumentNullException(nameof(param.ParameterType))); + if (param.Description != null) + { + schemaBuilder = schemaBuilder.Description(param.Description); + } + + if (param.IsRequired) + { + requiredParameterNames.Add(param.Name); + } + + var schema = schemaBuilder.Build(); + propertiesSchemas[param.Name] = schema; + + } + propertySchemaBuilder = propertySchemaBuilder.Properties(propertiesSchemas); + propertySchemaBuilder = propertySchemaBuilder.Required(requiredParameterNames); + + var jsonSchema = propertySchemaBuilder.Build(); + + return jsonSchema; + } } diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.approved.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.approved.txt index 51ad32bf9f..493b07d9ec 100644 --- a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.approved.txt +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.approved.txt @@ -1 +1,3 @@ -[INST]Hello?[/INST]World! \ No newline at end of file +[INST]You are a helpful AI assistant. + +Hello?[/INST]World! \ No newline at end of file diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.received.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.received.txt index 51ad32bf9f..493b07d9ec 100644 --- a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.received.txt +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.received.txt @@ -1 +1,3 @@ -[INST]Hello?[/INST]World! \ No newline at end of file +[INST]You are a helpful AI assistant. + +Hello?[/INST]World! \ No newline at end of file diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateWithToolsFromAutoGenChatHistory.approved.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateWithToolsFromAutoGenChatHistory.approved.txt new file mode 100644 index 0000000000..4731561ae7 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateWithToolsFromAutoGenChatHistory.approved.txt @@ -0,0 +1,3 @@ +[INST]What's the weather in Seattle?[/INST][TOOL_CALLS][{"name":"get_current_weather","arguments":{"location":"Seattle, WA"},"id":"9Ae3bDc2F"}][TOOL_RESULTS][{"id":"9Ae3bDc2F","content":"sunny"}][/TOOL_RESULTS]The current temperature in Seattle is 22.0 degrees celsius.[INST]What's the weather in New York?[/INST][TOOL_CALLS][{"name":"get_current_weather","arguments":{"location":"Seattle, WA"},"id":"9Ae3bDc2F"}][TOOL_RESULTS][{"id":"9Ae3bDc2F","content":"sunny"}][/TOOL_RESULTS]The current temperature in New York is 22.0 degrees celsius.[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}}, "required": ["location"]}}}][/AVAILABLE_TOOLS][INST]You are a helpful AI assistant. + +What's the weather in Paris?[/INST] \ No newline at end of file diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateWithToolsFromAutoGenChatHistory.received.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateWithToolsFromAutoGenChatHistory.received.txt new file mode 100644 index 0000000000..4731561ae7 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateWithToolsFromAutoGenChatHistory.received.txt @@ -0,0 +1,3 @@ +[INST]What's the weather in Seattle?[/INST][TOOL_CALLS][{"name":"get_current_weather","arguments":{"location":"Seattle, WA"},"id":"9Ae3bDc2F"}][TOOL_RESULTS][{"id":"9Ae3bDc2F","content":"sunny"}][/TOOL_RESULTS]The current temperature in Seattle is 22.0 degrees celsius.[INST]What's the weather in New York?[/INST][TOOL_CALLS][{"name":"get_current_weather","arguments":{"location":"Seattle, WA"},"id":"9Ae3bDc2F"}][TOOL_RESULTS][{"id":"9Ae3bDc2F","content":"sunny"}][/TOOL_RESULTS]The current temperature in New York is 22.0 degrees celsius.[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}}, "required": ["location"]}}}][/AVAILABLE_TOOLS][INST]You are a helpful AI assistant. + +What's the weather in Paris?[/INST] \ No newline at end of file diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs b/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs index bf85a911e4..70ac70cdc9 100644 --- a/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs @@ -53,6 +53,54 @@ public void ItBuildChatTemplateFromAutoGenChatHistory() Approvals.Verify(prompt); } + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("Approvals")] + public void ItBuildChatTemplateWithToolsFromAutoGenChatHistory() + { + var getWeatherTool = new FunctionContract + { + Name = "get_current_weather", + Namespace = "weather", + Description = "Get the current weather", + Parameters = [ + new FunctionParameterContract + { + Name = "location", + ParameterType = typeof(string), + Description = "The city and state, e.g. San Francisco, CA", + IsRequired = true + } + ] + }; + + var getWeatherToolCall = new ToolCall("get_current_weather", "{\"location\": \"Seattle, WA\"}") { ToolCallId = "9Ae3bDc2F" }; + var getWeatherToolCallResult = new ToolCall("get_current_weather", "{\"temperature\": 22.0}", "sunny") { ToolCallId = "9Ae3bDc2F" }; + var toolCallMessage = new ToolCallMessage([getWeatherToolCall]); + var toolCallResultMessage = new ToolCallResultMessage([getWeatherToolCallResult]); + var aggregateToolCallMessage = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage); + + var chatHistory = new List + { + new TextMessage(Role.System, "You are a helpful AI assistant."), + new TextMessage(Role.User, "What's the weather in Seattle?"), + toolCallMessage, + toolCallResultMessage, + new TextMessage(Role.Assistant, "The current temperature in Seattle is 22.0 degrees celsius."), + + // test tool call aggregate message for immediate tool call execution + new TextMessage(Role.User, "What's the weather in New York?"), + aggregateToolCallMessage, + new TextMessage(Role.Assistant, "The current temperature in New York is 22.0 degrees celsius."), + + new TextMessage(Role.User, "What's the weather in Paris?"), + }; + + var prompt = Mistral_7B_0_3ChatTemplateBuilder.Instance.BuildPrompt(chatHistory, [getWeatherTool]); + + Approvals.Verify(prompt); + } + [Fact] [UseReporter(typeof(DiffReporter))] [UseApprovalSubdirectory("Approvals")] From 4d7ffa917bddbb547c0342e7e245559f3e321a57 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 3 Sep 2024 17:47:37 -0700 Subject: [PATCH 04/17] update autogen to v 0.1.0 --- eng/Versions.props | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eng/Versions.props b/eng/Versions.props index 2510aa58f7..42242736bb 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -68,7 +68,7 @@ 2 2.3.1 1.4.1 - 0.0.15 + 0.1.0 1.15.0 0.102.7 2.2.1.1 From dd09e42eda606d10bc42da012681ce466c824ba5 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 3 Sep 2024 21:33:53 -0700 Subject: [PATCH 05/17] update autogen to 0.1.0 --- .../Microsoft.ML.GenAI.Samples.csproj | 1 - .../Mistral/Mistral_7B_Instruct.cs | 7 +------ src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs | 2 +- src/Microsoft.ML.GenAI.Mistral/MistralCausalLMAgent.cs | 2 +- src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs | 2 +- 5 files changed, 4 insertions(+), 10 deletions(-) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj index 596d149e38..792391a59f 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj @@ -19,7 +19,6 @@ - diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs index 4afc3806b6..663401e979 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs @@ -1,9 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -using AutoGen.Core; +using AutoGen.Core; using Microsoft.ML.GenAI.Core; using Microsoft.ML.GenAI.Mistral; using Microsoft.ML.Tokenizers; diff --git a/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs b/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs index 5deabd6df2..d6593f445f 100644 --- a/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs +++ b/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs @@ -60,7 +60,7 @@ public Task GenerateReplyAsync(IEnumerable messages, Generat } #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - public async IAsyncEnumerable GenerateStreamingReplyAsync( + public async IAsyncEnumerable GenerateStreamingReplyAsync( #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously IEnumerable messages, GenerateReplyOptions? options = null, diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralCausalLMAgent.cs b/src/Microsoft.ML.GenAI.Mistral/MistralCausalLMAgent.cs index f5677a73ef..e20d3b860b 100644 --- a/src/Microsoft.ML.GenAI.Mistral/MistralCausalLMAgent.cs +++ b/src/Microsoft.ML.GenAI.Mistral/MistralCausalLMAgent.cs @@ -70,7 +70,7 @@ public Task GenerateReplyAsync(IEnumerable messages, Generat } #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - public async IAsyncEnumerable GenerateStreamingReplyAsync( + public async IAsyncEnumerable GenerateStreamingReplyAsync( #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously IEnumerable messages, GenerateReplyOptions? options = null, diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs index abe1e92716..2b9e93a4a0 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs @@ -50,7 +50,7 @@ public Task GenerateReplyAsync(IEnumerable messages, Generat } #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - public async IAsyncEnumerable GenerateStreamingReplyAsync( + public async IAsyncEnumerable GenerateStreamingReplyAsync( #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously IEnumerable messages, GenerateReplyOptions? options = null, From acb0d2e641fac612a52bfbb41dfe7a20e9d68764 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 4 Sep 2024 08:54:12 -0700 Subject: [PATCH 06/17] remove tests on non-x64 machien --- .../Microsoft.ML.GenAI.Mistral.Tests.csproj | 1 + 1 file changed, 1 insertion(+) diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj b/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj index da2dde0bfc..4715947431 100644 --- a/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj @@ -26,6 +26,7 @@ + From f5514d3a12871bdfe7bef1bb652d486ba655b622 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 4 Sep 2024 08:55:41 -0700 Subject: [PATCH 07/17] add file header --- .../Mistral_7B_Instruct_V0_3Tests.cs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs b/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs index 70ac70cdc9..5a4a1ee089 100644 --- a/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs @@ -1,15 +1,15 @@ -using System; -using System.Collections.Generic; -using System.Linq; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using System.Text; -using System.Threading.Tasks; +using ApprovalTests; using ApprovalTests.Namers; using ApprovalTests.Reporters; -using ApprovalTests; -using Xunit; -using TorchSharp; -using Microsoft.ML.GenAI.Core.Extension; using AutoGen.Core; +using Microsoft.ML.GenAI.Core.Extension; +using TorchSharp; +using Xunit; namespace Microsoft.ML.GenAI.Mistral.Tests; From c2dca6492d54153967781af38c336a43f50006c1 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 4 Sep 2024 10:03:03 -0700 Subject: [PATCH 08/17] update --- .../Mistral/Mistral_7B_Instruct.cs | 59 ++++++++++++++++++- .../Microsoft.ML.GenAI.Samples/Program.cs | 2 +- .../Pipeline/CausalLMPipeline.cs | 5 +- .../MistralForCausalLM.cs | 11 +++- .../MistralModel.cs | 2 +- 5 files changed, 74 insertions(+), 5 deletions(-) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs index 663401e979..bbd0b0ab22 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs @@ -1,8 +1,11 @@ -using AutoGen.Core; +using System.Text.Json; +using AutoGen.Core; using Microsoft.ML.GenAI.Core; using Microsoft.ML.GenAI.Mistral; +using Microsoft.ML.GenAI.Mistral.Module; using Microsoft.ML.Tokenizers; using TorchSharp; +using TorchSharp.PyBridge; using static TorchSharp.torch; namespace Microsoft.ML.GenAI.Samples.Mistral; @@ -52,6 +55,60 @@ How are you. await agent.SendAsync(task); } + public static void Embedding() + { + var device = "cuda"; + if (device == "cuda") + { + torch.InitializeDeviceType(DeviceType.CUDA); + } + + var defaultType = ScalarType.Float32; + torch.manual_seed(1); + torch.set_default_dtype(defaultType); + var weightFolder = @"C:\Users\xiaoyuz\source\repos\bge-en-icl"; + var configName = "config.json"; + var originalWeightFolder = Path.Combine(weightFolder); + + Console.WriteLine("Loading Mistral from huggingface model weight folder"); + var tokenizer = MistralTokenizerHelper.FromPretrained(originalWeightFolder, modelName: "tokenizer.model"); + + var mistralConfig = JsonSerializer.Deserialize(File.ReadAllText(Path.Combine(weightFolder, configName))) ?? throw new ArgumentNullException(nameof(configName)); + var model = new MistralModel(mistralConfig); + model.load_checkpoint(weightFolder, "model.safetensors.index.json", strict: true, useTqdm: false); + model.to(device); + + var pipeline = new CausalLMPipeline(tokenizer, model, device); + + var query = """ + Given a web search query, retrieve relevant passages that answer the query. + what is a virtual interface + A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes. + + Given a web search query, retrieve relevant passages that answer the query. + causes of back pain in female for a week + Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper posture, spinal issues like herniated discs or osteoporosis, menstrual cramps causing referred pain, urinary tract infections, or pelvic inflammatory disease. Pregnancy-related changes can also contribute. Stress and lack of physical activity may exacerbate symptoms. Proper diagnosis by a healthcare professional is crucial for effective treatment and management. + + Given a web search query, retrieve relevant passages that answer the query. + how much protein should a female eat + + """; + + var document = """ + As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day. + """; + var queryEmbedding = pipeline.GenerateEmbeddingFromLastTokenPool(query); + var documentEmbedding = pipeline.GenerateEmbeddingFromLastTokenPool(document); + + var score = 0f; + foreach (var (q, d) in queryEmbedding.Zip(documentEmbedding)) + { + score += q * d * 100; + } + + Console.WriteLine($"The similarity score between query and document is {score}"); + } + public async static Task WeatherChatAsync() { var device = "cuda"; diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs index cf166c7552..cfbaec7591 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs @@ -2,4 +2,4 @@ using Microsoft.ML.GenAI.Samples.Mistral; using Microsoft.ML.GenAI.Samples.Phi3Mini; -await Mistral_7B_Instruct.WeatherChatAsync(); +Mistral_7B_Instruct.Embedding(); diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs index 72da7c21d7..7363dafa40 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs @@ -301,7 +301,10 @@ public float[] GenerateEmbeddingFromLastTokenPool(string prompt) var inputIds = this.Tokenizer.EncodeToIds(prompt); var inputTensor = torch.tensor(inputIds.ToArray(), dtype: ScalarType.Int64, device: this.Device).unsqueeze(0); var attentionMask = torch.ones_like(inputTensor, device: this.Device); - var input = new CausalLMModelInput(inputTensor, attentionMask, pastKeyValuesLength: 0); + var input = new CausalLMModelInput(inputTensor, attentionMask, pastKeyValuesLength: 0) + { + OverrideCache = new DynamicKVCache(), + }; var output = this.Model.forward(input); var lastTokenHiddenState = output.LastHiddenState[0, ^1]; diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs b/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs index 3d3ac11bbd..590aef1f25 100644 --- a/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs +++ b/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs @@ -116,6 +116,15 @@ public static MistralForCausalLM FromPretrained( public void LoadSafeTensors(string modelFolder, string checkPointName = "model.safetensors.index.json") { - this.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: true, useTqdm: false); + // print the shape of model + var shape = this.Peek(); + Console.WriteLine($"Model shape: {shape}"); + var loadedDictionary = new Dictionary(); + this.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, loadedParameters: loadedDictionary, useTqdm: false); + + foreach (var (key, succeed) in loadedDictionary) + { + Console.WriteLine($"Loading {key} {(succeed ? "succeed" : "failed")}"); + } } } diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralModel.cs b/src/Microsoft.ML.GenAI.Mistral/MistralModel.cs index 6c05fe53e9..cab7e6cc5a 100644 --- a/src/Microsoft.ML.GenAI.Mistral/MistralModel.cs +++ b/src/Microsoft.ML.GenAI.Mistral/MistralModel.cs @@ -9,7 +9,7 @@ namespace Microsoft.ML.GenAI.Mistral.Module; -internal class MistralModel : nn.Module +public class MistralModel : nn.Module { private readonly MistralConfig _config; private readonly int? _paddingIdx; From 11446f86638a05b3c25f5bc3361a134f43c866ea Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 4 Sep 2024 23:31:10 -0700 Subject: [PATCH 09/17] update --- .../MistralForCausalLM.cs | 20 +++++++++---------- .../MistralTokenizerHelper.cs | 11 +++++++++- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs b/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs index 590aef1f25..18d43e5317 100644 --- a/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs +++ b/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs @@ -107,6 +107,15 @@ public static MistralForCausalLM FromPretrained( model.LoadSafeTensors(modelFolder, checkPointName); + if (quantizeToInt8) + { + model.ToInt8QuantizeModule(); + } + else if (quantizeToInt4) + { + model.ToInt4QuantizeModule(); + } + model = model.ToDynamicLoadingModel(deviceMap, targetDevice); torch.set_default_device(originalDefaultDevice); @@ -116,15 +125,6 @@ public static MistralForCausalLM FromPretrained( public void LoadSafeTensors(string modelFolder, string checkPointName = "model.safetensors.index.json") { - // print the shape of model - var shape = this.Peek(); - Console.WriteLine($"Model shape: {shape}"); - var loadedDictionary = new Dictionary(); - this.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, loadedParameters: loadedDictionary, useTqdm: false); - - foreach (var (key, succeed) in loadedDictionary) - { - Console.WriteLine($"Loading {key} {(succeed ? "succeed" : "failed")}"); - } + this.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, useTqdm: false); } } diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralTokenizerHelper.cs b/src/Microsoft.ML.GenAI.Mistral/MistralTokenizerHelper.cs index d1d8f46e22..3ed9a79780 100644 --- a/src/Microsoft.ML.GenAI.Mistral/MistralTokenizerHelper.cs +++ b/src/Microsoft.ML.GenAI.Mistral/MistralTokenizerHelper.cs @@ -57,7 +57,8 @@ public static LlamaTokenizer FromPretrained( int startToolResultSymbolId = 8, string endToolResultSymbol = EndToolResultSymbol, int endToolResultSymbolId = 9, - bool addPrecedingSpace = true) + bool addPrecedingSpace = true, + Dictionary? additionalSpecialTokens = null) { var specialTokens = new Dictionary { @@ -72,6 +73,14 @@ public static LlamaTokenizer FromPretrained( { endToolResultSymbol, endToolResultSymbolId } }; + if (additionalSpecialTokens != null) + { + foreach (var (key, value) in additionalSpecialTokens) + { + specialTokens[key] = value; + } + } + return FromPretrained( modelWeightFolder, modelName, From aadb6c01f6c8172569d2fcf76cba6b55c4326371 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Thu, 5 Sep 2024 15:42:05 -0700 Subject: [PATCH 10/17] update ml tokenizer test version --- eng/Versions.props | 2 +- .../Mistral_7B_Instruct_V0_3Tests.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/eng/Versions.props b/eng/Versions.props index 42242736bb..fde5bdbeb8 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -96,7 +96,7 @@ 0.0.13-test 0.0.6-test 0.0.7-test - 2.0.0-beta.24415.1 + 2.0.0-beta.24455.2 4.8.6 1.0.118 1.6.24 diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs b/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs index 5a4a1ee089..e68de05dd6 100644 --- a/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs @@ -106,7 +106,7 @@ public void ItBuildChatTemplateWithToolsFromAutoGenChatHistory() [UseApprovalSubdirectory("Approvals")] public void TokenizerTest() { - var modelWeightFolder = Path.Join("C:\\Users\\xiaoyuz\\source\\repos\\Mistral-7B-Instruct-v0.3"); + var modelWeightFolder = "Mistral"; var tokenizer = MistralTokenizerHelper.FromPretrained(modelWeightFolder); var messages = new string[] From 54cfaf948f70198d9e3eda9cedcf942289d7f3aa Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Thu, 5 Sep 2024 16:17:36 -0700 Subject: [PATCH 11/17] fix build error --- .../Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs index bbd0b0ab22..1a6c09e405 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs @@ -109,7 +109,7 @@ public static void Embedding() Console.WriteLine($"The similarity score between query and document is {score}"); } - public async static Task WeatherChatAsync() + public static async Task WeatherChatAsync() { var device = "cuda"; if (device == "cuda") From ad5981c6f3f3b4943586ccfdddfa8bb2b7a44fe6 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 6 Sep 2024 09:55:56 -0700 Subject: [PATCH 12/17] remove .receive.txt --- ...ests.ItBuildChatTemplateFromAutoGenChatHistory.received.txt | 3 --- ...ildChatTemplateWithToolsFromAutoGenChatHistory.received.txt | 3 --- 2 files changed, 6 deletions(-) delete mode 100644 test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.received.txt delete mode 100644 test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateWithToolsFromAutoGenChatHistory.received.txt diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.received.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.received.txt deleted file mode 100644 index 493b07d9ec..0000000000 --- a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.received.txt +++ /dev/null @@ -1,3 +0,0 @@ -[INST]You are a helpful AI assistant. - -Hello?[/INST]World! \ No newline at end of file diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateWithToolsFromAutoGenChatHistory.received.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateWithToolsFromAutoGenChatHistory.received.txt deleted file mode 100644 index 4731561ae7..0000000000 --- a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateWithToolsFromAutoGenChatHistory.received.txt +++ /dev/null @@ -1,3 +0,0 @@ -[INST]What's the weather in Seattle?[/INST][TOOL_CALLS][{"name":"get_current_weather","arguments":{"location":"Seattle, WA"},"id":"9Ae3bDc2F"}][TOOL_RESULTS][{"id":"9Ae3bDc2F","content":"sunny"}][/TOOL_RESULTS]The current temperature in Seattle is 22.0 degrees celsius.[INST]What's the weather in New York?[/INST][TOOL_CALLS][{"name":"get_current_weather","arguments":{"location":"Seattle, WA"},"id":"9Ae3bDc2F"}][TOOL_RESULTS][{"id":"9Ae3bDc2F","content":"sunny"}][/TOOL_RESULTS]The current temperature in New York is 22.0 degrees celsius.[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}}, "required": ["location"]}}}][/AVAILABLE_TOOLS][INST]You are a helpful AI assistant. - -What's the weather in Paris?[/INST] \ No newline at end of file From 94434bae6fe8aa071cf1b7a7a3ff23b0e93dd488 Mon Sep 17 00:00:00 2001 From: Xiaoyun Zhang Date: Mon, 9 Sep 2024 12:37:33 -0700 Subject: [PATCH 13/17] Update docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs Co-authored-by: Weihan Li <7604648+WeihanLi@users.noreply.github.com> --- .../Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs index 1a6c09e405..6f1b60b699 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs @@ -19,9 +19,9 @@ public partial class Mistral_7B_Instruct /// /// [Function] - public async Task GetWeather(string city) + public Task GetWeather(string city) { - return await Task.FromResult($"The weather in {city} is sunny."); + return Task.FromResult($"The weather in {city} is sunny."); } public static async Task RunAsync() From 68a212e80014cf480d796f245e94628e7aae0a95 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Mon, 9 Sep 2024 21:54:32 -0700 Subject: [PATCH 14/17] update --- .../Microsoft.ML.GenAI.Samples/Program.cs | 2 +- .../Pipeline/CausalLMPipeline.cs | 17 ++++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs index cfbaec7591..cf166c7552 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs @@ -2,4 +2,4 @@ using Microsoft.ML.GenAI.Samples.Mistral; using Microsoft.ML.GenAI.Samples.Phi3Mini; -Mistral_7B_Instruct.Embedding(); +await Mistral_7B_Instruct.WeatherChatAsync(); diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs index 7363dafa40..c368378337 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs @@ -266,15 +266,18 @@ public virtual IEnumerable GenerateStreaming( foreach (var (token, _) in this.GenerateStreaming(inputTensor, attentionMask, stopTokenIds.ToArray(), temperature: temperature, maxLen: maxLen)) { var tokenIds = token[0].to_type(ScalarType.Int32).data().ToArray(); - if (this.Tokenizer is LlamaTokenizer llamaTokenizer && llamaTokenizer.SpecialTokens?.FirstOrDefault(kv => kv.Value == tokenIds[0]) is { Key: string specialToken }) + var duplicateTokenString = this.Tokenizer switch { - Console.WriteLine($"token: {tokenIds[0]}"); - yield return specialToken; - continue; - } + SentencePieceBpeTokenizer bpeTokenizer => bpeTokenizer.Decode(tokenIds.Concat(tokenIds), considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"), + _ => this.Tokenizer.Decode(tokenIds.Concat(tokenIds)) ?? throw new InvalidOperationException("Failed to decode token ids"), + }; + + var tokenString = this.Tokenizer switch + { + SentencePieceBpeTokenizer bpeTokenizer => bpeTokenizer.Decode(tokenIds, considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"), + _ => this.Tokenizer.Decode(tokenIds) ?? throw new InvalidOperationException("Failed to decode token ids"), + }; - var duplicateTokenString = this.Tokenizer.Decode(tokenIds.Concat(tokenIds)) ?? throw new InvalidOperationException("Failed to decode token ids"); - var tokenString = this.Tokenizer.Decode(tokenIds) ?? throw new InvalidOperationException("Failed to decode token ids"); // replace the first occurrence of the token with the duplicate token tokenString = duplicateTokenString.Substring(tokenString.Length); From 1bc69bde4879049d11591360f487cef00facaca0 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Mon, 9 Sep 2024 22:07:20 -0700 Subject: [PATCH 15/17] set t to 0 --- .../Mistral/Mistral_7B_Instruct.cs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs index 6f1b60b699..25580090fe 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs @@ -144,7 +144,11 @@ public static async Task WeatherChatAsync() var task = "what is the weather in Seattle"; var userMessage = new TextMessage(Role.User, task); - var reply = await agent.SendAsync(userMessage); + var reply = await agent.GenerateReplyAsync(messages: [userMessage], + new GenerateReplyOptions + { + Temperature = 0f, + }); // generate further reply using tool call result; await agent.SendAsync(chatHistory: [userMessage, reply]); From 05f726de3b46d118880c65e3b0a9b833f51a41d4 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 10 Sep 2024 09:25:29 -0700 Subject: [PATCH 16/17] fix test --- .../Mistral_7B_Instruct_V0_3Tests.TokenizerTest.approved.txt | 4 ++-- .../Mistral_7B_Instruct_V0_3Tests.cs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.TokenizerTest.approved.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.TokenizerTest.approved.txt index fc8562c9e9..0287bd2f22 100644 --- a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.TokenizerTest.approved.txt +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.TokenizerTest.approved.txt @@ -1,2 +1,2 @@ - [{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "format": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location."}}, "required": ["location", "format"]}}}] What's the weather like in Paris? [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}, "id": "9Ae3bDc2F"}] {"content": 22.0, "call_id": "9Ae3bDc2F"} The current temperature in Paris is 22.0 degrees celsius. -1, 1, 6, 1501, 7567, 1891, 2032, 1113, 3396, 1316, 1113, 3396, 2032, 10598, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 1537, 1991, 1316, 1113, 7286, 2032, 1113, 2226, 1040, 2636, 8854, 1316, 1113, 12206, 2032, 10598, 1891, 2032, 1113, 3582, 1316, 1113, 11491, 2032, 10598, 3501, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 7286, 2032, 1113, 1782, 3758, 1072, 2433, 29493, 1085, 29491, 29489, 29491, 4420, 10454, 29493, 10229, 8474, 1113, 4530, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 10825, 2032, 8135, 29485, 1958, 3938, 1316, 1113, 29490, 19425, 13075, 9651, 1113, 7286, 2032, 1113, 1782, 8409, 5796, 1066, 1706, 29491, 1328, 1410, 1224, 1245, 1040, 6211, 5491, 1379, 11549, 1113, 11661, 2032, 8135, 3501, 1316, 1113, 4530, 3010, 14879, 29561, 7, 3, 2592, 29510, 29481, 1040, 8854, 1505, 1065, 6233, 29572, 4, 5, 1501, 7567, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 29475, 17329, 1316, 1113, 17452, 2032, 10598, 3501, 2032, 1113, 4684, 1046, 29493, 5611, 1316, 1113, 6074, 2032, 1113, 29485, 1958, 3938, 8474, 1113, 1081, 2032, 1113, 29542, 29509, 29474, 29538, 29494, 29525, 29485, 29518, 29533, 29507, 10925, 2, 8, 10598, 4557, 2032, 29473, 29518, 29518, 29491, 29502, 29493, 1113, 3613, 29498, 1081, 2032, 1113, 29542, 29509, 29474, 29538, 29494, 29525, 29485, 29518, 29533, 18163, 9, 1183, 2636, 8409, 1065, 6233, 1117, 29473, 29518, 29518, 29491, 29502, 11950, 1045, 1958, 3938, 29491, 2 +[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "format": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location."}}, "required": ["location", "format"]}}}][/AVAILABLE_TOOLS][INST] What's the weather like in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}, "id": "9Ae3bDc2F"}][TOOL_RESULTS] {"content": 22.0, "call_id": "9Ae3bDc2F"}[/TOOL_RESULTS] The current temperature in Paris is 22.0 degrees celsius. +1, 6, 1501, 7567, 1891, 2032, 1113, 3396, 1316, 1113, 3396, 2032, 10598, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 1537, 1991, 1316, 1113, 7286, 2032, 1113, 2226, 1040, 2636, 8854, 1316, 1113, 12206, 2032, 10598, 1891, 2032, 1113, 3582, 1316, 1113, 11491, 2032, 10598, 3501, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 7286, 2032, 1113, 1782, 3758, 1072, 2433, 29493, 1085, 29491, 29489, 29491, 4420, 10454, 29493, 10229, 8474, 1113, 4530, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 10825, 2032, 8135, 29485, 1958, 3938, 1316, 1113, 29490, 19425, 13075, 9651, 1113, 7286, 2032, 1113, 1782, 8409, 5796, 1066, 1706, 29491, 1328, 1410, 1224, 1245, 1040, 6211, 5491, 1379, 11549, 1113, 11661, 2032, 8135, 3501, 1316, 1113, 4530, 3010, 14879, 29561, 7, 3, 2592, 29510, 29481, 1040, 8854, 1505, 1065, 6233, 29572, 4, 5, 1501, 7567, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 29475, 17329, 1316, 1113, 17452, 2032, 10598, 3501, 2032, 1113, 4684, 1046, 29493, 5611, 1316, 1113, 6074, 2032, 1113, 29485, 1958, 3938, 8474, 1113, 1081, 2032, 1113, 29542, 29509, 29474, 29538, 29494, 29525, 29485, 29518, 29533, 29507, 10925, 2, 8, 10598, 4557, 2032, 29473, 29518, 29518, 29491, 29502, 29493, 1113, 3613, 29498, 1081, 2032, 1113, 29542, 29509, 29474, 29538, 29494, 29525, 29485, 29518, 29533, 18163, 9, 1183, 2636, 8409, 1065, 6233, 1117, 29473, 29518, 29518, 29491, 29502, 11950, 1045, 1958, 3938, 29491, 2 diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs b/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs index e68de05dd6..0aa80e8880 100644 --- a/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs +++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs @@ -125,8 +125,8 @@ public void TokenizerTest() var sb = new StringBuilder(); foreach (var message in messages) { - var tokenizeIds = tokenizer.EncodeToIds(message, true, false); - var decodeToString = tokenizer.Decode(tokenizeIds); + var tokenizeIds = tokenizer.EncodeToIds(message, false, false); + var decodeToString = tokenizer.Decode(tokenizeIds, considerSpecialTokens: true); sb.AppendLine(decodeToString); var tokenizedStr = string.Join(", ", tokenizeIds.Select(x => x.ToString())); From 7eb9f58f659fbb82f9ed0bb990091fc3f235bab7 Mon Sep 17 00:00:00 2001 From: Xiaoyun Zhang Date: Wed, 25 Sep 2024 14:51:44 -0700 Subject: [PATCH 17/17] Update Microsoft.ML.GenAI.Mistral.csproj --- src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj b/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj index 5b0cb0acc0..896f47e5b7 100644 --- a/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj +++ b/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj @@ -4,6 +4,7 @@ net6.0;net8.0 enable enable + true