Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GenAI] Add Mistral 7B Instruction V0.3 #7231

Merged
merged 19 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion Microsoft.ML.sln
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.Core.Tes
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.LLaMA", "src\Microsoft.ML.GenAI.LLaMA\Microsoft.ML.GenAI.LLaMA.csproj", "{0AA6D5CB-195F-457A-8792-4221E76E6C44}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.LLaMA.Tests", "test\Microsoft.ML.GenAI.LLaMA.Tests\Microsoft.ML.GenAI.LLaMA.Tests.csproj", "{D202353D-6FAF-4263-9A01-BDCFBC92391F}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.LLaMA.Tests", "test\Microsoft.ML.GenAI.LLaMA.Tests\Microsoft.ML.GenAI.LLaMA.Tests.csproj", "{D202353D-6FAF-4263-9A01-BDCFBC92391F}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.Mistral", "src\Microsoft.ML.GenAI.Mistral\Microsoft.ML.GenAI.Mistral.csproj", "{2729CC66-7743-442B-B3A5-1F4F27F044A5}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.Mistral.Tests", "test\Microsoft.ML.GenAI.Mistral.Tests\Microsoft.ML.GenAI.Mistral.Tests.csproj", "{49264202-C90A-43F6-8C30-BDAEF2F1465A}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Expand Down Expand Up @@ -898,6 +902,22 @@ Global
{D202353D-6FAF-4263-9A01-BDCFBC92391F}.Release|Any CPU.Build.0 = Release|Any CPU
{D202353D-6FAF-4263-9A01-BDCFBC92391F}.Release|x64.ActiveCfg = Release|Any CPU
{D202353D-6FAF-4263-9A01-BDCFBC92391F}.Release|x64.Build.0 = Release|Any CPU
{2729CC66-7743-442B-B3A5-1F4F27F044A5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{2729CC66-7743-442B-B3A5-1F4F27F044A5}.Debug|Any CPU.Build.0 = Debug|Any CPU
{2729CC66-7743-442B-B3A5-1F4F27F044A5}.Debug|x64.ActiveCfg = Debug|Any CPU
{2729CC66-7743-442B-B3A5-1F4F27F044A5}.Debug|x64.Build.0 = Debug|Any CPU
{2729CC66-7743-442B-B3A5-1F4F27F044A5}.Release|Any CPU.ActiveCfg = Release|Any CPU
{2729CC66-7743-442B-B3A5-1F4F27F044A5}.Release|Any CPU.Build.0 = Release|Any CPU
{2729CC66-7743-442B-B3A5-1F4F27F044A5}.Release|x64.ActiveCfg = Release|Any CPU
{2729CC66-7743-442B-B3A5-1F4F27F044A5}.Release|x64.Build.0 = Release|Any CPU
{49264202-C90A-43F6-8C30-BDAEF2F1465A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{49264202-C90A-43F6-8C30-BDAEF2F1465A}.Debug|Any CPU.Build.0 = Debug|Any CPU
{49264202-C90A-43F6-8C30-BDAEF2F1465A}.Debug|x64.ActiveCfg = Debug|Any CPU
{49264202-C90A-43F6-8C30-BDAEF2F1465A}.Debug|x64.Build.0 = Debug|Any CPU
{49264202-C90A-43F6-8C30-BDAEF2F1465A}.Release|Any CPU.ActiveCfg = Release|Any CPU
{49264202-C90A-43F6-8C30-BDAEF2F1465A}.Release|Any CPU.Build.0 = Release|Any CPU
{49264202-C90A-43F6-8C30-BDAEF2F1465A}.Release|x64.ActiveCfg = Release|Any CPU
{49264202-C90A-43F6-8C30-BDAEF2F1465A}.Release|x64.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -991,6 +1011,8 @@ Global
{14AB0804-D4CE-4634-B544-5A8587620783} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{0AA6D5CB-195F-457A-8792-4221E76E6C44} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{D202353D-6FAF-4263-9A01-BDCFBC92391F} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{2729CC66-7743-442B-B3A5-1F4F27F044A5} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{49264202-C90A-43F6-8C30-BDAEF2F1465A} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\src\Microsoft.ML.GenAI.Core\Microsoft.ML.GenAI.Core.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.GenAI.LLaMA\Microsoft.ML.GenAI.LLaMA.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.GenAI.Mistral\Microsoft.ML.GenAI.Mistral.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.GenAI.Phi\Microsoft.ML.GenAI.Phi.csproj" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="TorchSharp-cuda-windows" Version="0.102.5" Condition="$([MSBuild]::IsOSPlatform('Windows'))" />
<PackageReference Include="Microsoft.SemanticKernel" Version="$(SemanticKernelVersion)" />
<PackageReference Include="AutoGen.SourceGenerator" Version="$(AutoGenVersion)" />
</ItemGroup>

</Project>
156 changes: 156 additions & 0 deletions docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
using System.Text.Json;
using AutoGen.Core;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Mistral;
using Microsoft.ML.GenAI.Mistral.Module;
using Microsoft.ML.Tokenizers;
using TorchSharp;
using TorchSharp.PyBridge;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Samples.Mistral;

public partial class Mistral_7B_Instruct
{
private static Mistral_7B_Instruct instance = new Mistral_7B_Instruct();

/// <summary>
/// get weather from city
/// </summary>
/// <param name="city"></param>
[Function]
public Task<string> GetWeather(string city)
{
return Task.FromResult($"The weather in {city} is sunny.");
}

public static async Task RunAsync()
{
var device = "cuda";
if (device == "cuda")
{
torch.InitializeDeviceType(DeviceType.CUDA);
}

var defaultType = ScalarType.BFloat16;
torch.manual_seed(1);
torch.set_default_dtype(defaultType);
var weightFolder = @"C:\Users\xiaoyuz\source\repos\Mistral-7B-Instruct-v0.3";
var configName = "config.json";
var originalWeightFolder = Path.Combine(weightFolder);

Console.WriteLine("Loading Mistral from huggingface model weight folder");
var tokenizer = MistralTokenizerHelper.FromPretrained(originalWeightFolder);
var model = MistralForCausalLM.FromPretrained(weightFolder, configName, layersOnTargetDevice: -1);

var pipeline = new CausalLMPipeline<LlamaTokenizer, MistralForCausalLM>(tokenizer, model, device);

var agent = new MistralCausalLMAgent(pipeline, "assistant")
.RegisterPrintMessage();

var task = """
How are you.
""";

await agent.SendAsync(task);
}

public static void Embedding()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@luisquintanilla This will be how bge-en-icl embedding model consumed.

{
var device = "cuda";
if (device == "cuda")
{
torch.InitializeDeviceType(DeviceType.CUDA);
}

var defaultType = ScalarType.Float32;
torch.manual_seed(1);
torch.set_default_dtype(defaultType);
var weightFolder = @"C:\Users\xiaoyuz\source\repos\bge-en-icl";
var configName = "config.json";
var originalWeightFolder = Path.Combine(weightFolder);

Console.WriteLine("Loading Mistral from huggingface model weight folder");
var tokenizer = MistralTokenizerHelper.FromPretrained(originalWeightFolder, modelName: "tokenizer.model");

var mistralConfig = JsonSerializer.Deserialize<MistralConfig>(File.ReadAllText(Path.Combine(weightFolder, configName))) ?? throw new ArgumentNullException(nameof(configName));
var model = new MistralModel(mistralConfig);
model.load_checkpoint(weightFolder, "model.safetensors.index.json", strict: true, useTqdm: false);
model.to(device);

var pipeline = new CausalLMPipeline<LlamaTokenizer, MistralModel>(tokenizer, model, device);

var query = """
<instruct>Given a web search query, retrieve relevant passages that answer the query.
<query>what is a virtual interface
<response>A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes.

<instruct>Given a web search query, retrieve relevant passages that answer the query.
<query>causes of back pain in female for a week
<response>Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper posture, spinal issues like herniated discs or osteoporosis, menstrual cramps causing referred pain, urinary tract infections, or pelvic inflammatory disease. Pregnancy-related changes can also contribute. Stress and lack of physical activity may exacerbate symptoms. Proper diagnosis by a healthcare professional is crucial for effective treatment and management.

<instruct>Given a web search query, retrieve relevant passages that answer the query.
<query>how much protein should a female eat
<response>
""";

var document = """
As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.
""";
var queryEmbedding = pipeline.GenerateEmbeddingFromLastTokenPool(query);
var documentEmbedding = pipeline.GenerateEmbeddingFromLastTokenPool(document);

var score = 0f;
foreach (var (q, d) in queryEmbedding.Zip(documentEmbedding))
{
score += q * d * 100;
}

Console.WriteLine($"The similarity score between query and document is {score}");
}

public static async Task WeatherChatAsync()
{
var device = "cuda";
if (device == "cuda")
{
torch.InitializeDeviceType(DeviceType.CUDA);
}

var defaultType = ScalarType.BFloat16;
torch.manual_seed(1);
torch.set_default_dtype(defaultType);
var weightFolder = @"C:\Users\xiaoyuz\source\repos\Mistral-7B-Instruct-v0.3";
var configName = "config.json";
var originalWeightFolder = Path.Combine(weightFolder);

Console.WriteLine("Loading Mistral from huggingface model weight folder");
var tokenizer = MistralTokenizerHelper.FromPretrained(originalWeightFolder);
var model = MistralForCausalLM.FromPretrained(weightFolder, configName, layersOnTargetDevice: -1);

var pipeline = new CausalLMPipeline<LlamaTokenizer, MistralForCausalLM>(tokenizer, model, device);

var weatherChatMiddleware = new FunctionCallMiddleware(
functions: [instance.GetWeatherFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ instance.GetWeatherFunctionContract.Name!, instance.GetWeatherWrapper }
});

var agent = new MistralCausalLMAgent(pipeline, "assistant")
.RegisterStreamingMiddleware(weatherChatMiddleware)
.RegisterPrintMessage();

var task = "what is the weather in Seattle";
var userMessage = new TextMessage(Role.User, task);

var reply = await agent.GenerateReplyAsync(messages: [userMessage],
new GenerateReplyOptions
{
Temperature = 0f,
});

// generate further reply using tool call result;
await agent.SendAsync(chatHistory: [userMessage, reply]);
}
}
3 changes: 2 additions & 1 deletion docs/samples/Microsoft.ML.GenAI.Samples/Program.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// See https://aka.ms/new-console-template for more information
using Microsoft.ML.GenAI.Samples.Mistral;
using Microsoft.ML.GenAI.Samples.Phi3Mini;

await AutoGenSample.RunAsync();
await Mistral_7B_Instruct.WeatherChatAsync();
4 changes: 2 additions & 2 deletions eng/Versions.props
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
<TensorFlowMajorVersion>2</TensorFlowMajorVersion>
<TensorFlowVersion>2.3.1</TensorFlowVersion>
<TorchSharpPyBridgeVersion>1.4.1</TorchSharpPyBridgeVersion>
<AutoGenVersion>0.0.15</AutoGenVersion>
<AutoGenVersion>0.1.0</AutoGenVersion>
<SemanticKernelVersion>1.15.0</SemanticKernelVersion>
<TorchSharpVersion>0.102.7</TorchSharpVersion>
<LibTorchVersion>2.2.1.1</LibTorchVersion>
Expand Down Expand Up @@ -96,7 +96,7 @@
<MicrosoftMLTensorFlowTestModelsVersion>0.0.13-test</MicrosoftMLTensorFlowTestModelsVersion>
<MicrosoftMLTestDatabasesVersion>0.0.6-test</MicrosoftMLTestDatabasesVersion>
<MicrosoftMLTestModelsVersion>0.0.7-test</MicrosoftMLTestModelsVersion>
<MicrosoftMLTestTokenizersVersion>2.0.0-beta.24415.1</MicrosoftMLTestTokenizersVersion>
<MicrosoftMLTestTokenizersVersion>2.0.0-beta.24455.2</MicrosoftMLTestTokenizersVersion>
<SystemDataSqlClientVersion>4.8.6</SystemDataSqlClientVersion>
<SystemDataSQLiteCoreVersion>1.0.118</SystemDataSQLiteCoreVersion>
<XunitCombinatorialVersion>1.6.24</XunitCombinatorialVersion>
Expand Down
4 changes: 3 additions & 1 deletion src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@

<ItemGroup>
<InternalsVisibleTo Include="Microsoft.ML.GenAI.Phi" />
<InternalsVisibleTo Include="Microsoft.ML.GenAI.Phi.Tests" />
<InternalsVisibleTo Include="Microsoft.ML.GenAI.LLaMA" />
<InternalsVisibleTo Include="Microsoft.ML.GenAI.LLaMA.Tests" />
<InternalsVisibleTo Include="Microsoft.ML.GenAI.Phi.Tests" />
<InternalsVisibleTo Include="Microsoft.ML.GenAI.Mistral" />
<InternalsVisibleTo Include="Microsoft.ML.GenAI.Mistral.Tests" />
<InternalsVisibleTo Include="Microsoft.ML.GenAI.Core.Tests" />
</ItemGroup>

Expand Down
19 changes: 16 additions & 3 deletions src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,18 @@ public virtual IEnumerable<string> GenerateStreaming(
foreach (var (token, _) in this.GenerateStreaming(inputTensor, attentionMask, stopTokenIds.ToArray(), temperature: temperature, maxLen: maxLen))
{
var tokenIds = token[0].to_type(ScalarType.Int32).data<int>().ToArray();
var duplicateTokenString = this.Tokenizer.Decode(tokenIds.Concat(tokenIds)) ?? throw new InvalidOperationException("Failed to decode token ids");
var tokenString = this.Tokenizer.Decode(tokenIds) ?? throw new InvalidOperationException("Failed to decode token ids");
var duplicateTokenString = this.Tokenizer switch
{
SentencePieceBpeTokenizer bpeTokenizer => bpeTokenizer.Decode(tokenIds.Concat(tokenIds), considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"),
_ => this.Tokenizer.Decode(tokenIds.Concat(tokenIds)) ?? throw new InvalidOperationException("Failed to decode token ids"),
};

var tokenString = this.Tokenizer switch
{
SentencePieceBpeTokenizer bpeTokenizer => bpeTokenizer.Decode(tokenIds, considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"),
_ => this.Tokenizer.Decode(tokenIds) ?? throw new InvalidOperationException("Failed to decode token ids"),
};

// replace the first occurrence of the token with the duplicate token
tokenString = duplicateTokenString.Substring(tokenString.Length);

Expand All @@ -294,7 +304,10 @@ public float[] GenerateEmbeddingFromLastTokenPool(string prompt)
var inputIds = this.Tokenizer.EncodeToIds(prompt);
var inputTensor = torch.tensor(inputIds.ToArray(), dtype: ScalarType.Int64, device: this.Device).unsqueeze(0);
var attentionMask = torch.ones_like(inputTensor, device: this.Device);
var input = new CausalLMModelInput(inputTensor, attentionMask, pastKeyValuesLength: 0);
var input = new CausalLMModelInput(inputTensor, attentionMask, pastKeyValuesLength: 0)
{
OverrideCache = new DynamicKVCache(),
};
var output = this.Model.forward(input);
var lastTokenHiddenState = output.LastHiddenState[0, ^1];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public interface ISemanticKernelChatTemplateBuilder

public interface IAutoGenChatTemplateBuilder
{
string BuildPrompt(IEnumerable<IMessage> messages);
string BuildPrompt(IEnumerable<IMessage> messages, IEnumerable<FunctionContract>? tools = null);
}

public interface IChatTemplateBuilder : IAutoGenChatTemplateBuilder, ISemanticKernelChatTemplateBuilder
Expand Down
15 changes: 15 additions & 0 deletions src/Microsoft.ML.GenAI.Core/Utils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using TorchSharp;
Expand Down Expand Up @@ -161,4 +162,18 @@ public static Tensor RepeatKV(Tensor x, int nRep)
.reshape(batchSize, nKVHeads * nRep, seqLen, headDim);
}

internal static string GetEmbeddedResource(string resourceName)
{
// read file content from embedded resource
var assembly = Assembly.GetCallingAssembly();
var resourceStream = assembly.GetManifestResourceStream(resourceName);

if (resourceStream == null)
{
throw new ArgumentException("Resource not found", resourceName);
}

using var reader = new System.IO.StreamReader(resourceStream);
return reader.ReadToEnd();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder
{
private const char Newline = '\n';

public string BuildPrompt(IEnumerable<IMessage> messages)
public string BuildPrompt(IEnumerable<IMessage> messages, IEnumerable<FunctionContract>? tools = null)
{
var availableRoles = new[] { Role.System, Role.User, Role.Assistant };
if (messages.Any(m => m.GetContent() is null))
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, Generat
}

#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
Expand Down
6 changes: 0 additions & 6 deletions src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Core.Extension;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
Expand Down
27 changes: 0 additions & 27 deletions src/Microsoft.ML.GenAI.LLaMA/Utils.cs

This file was deleted.

Loading
Loading