-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GenAI] Add Mistral 7B Instruction V0.3 #7231
Merged
Merged
Changes from 18 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
6e005eb
add mistral and tests
LittleLittleCloud 7bd9338
add test and sample
LittleLittleCloud 51f6d6c
add tool call support
LittleLittleCloud 4d7ffa9
update autogen to v 0.1.0
LittleLittleCloud dd09e42
update autogen to 0.1.0
LittleLittleCloud acb0d2e
remove tests on non-x64 machine
LittleLittleCloud f5514d3
add file header
LittleLittleCloud c2dca64
update
LittleLittleCloud 11446f8
update
LittleLittleCloud aadb6c0
update ml tokenizer test version
LittleLittleCloud ae736cf
Merge branch 'u/mistral' of https://github.com/LittleLittleCloud/mach…
LittleLittleCloud 54cfaf9
fix build error
LittleLittleCloud ad5981c
remove .receive.txt
LittleLittleCloud 94434ba
Update docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Ins…
LittleLittleCloud 92dd526
Merge branch 'main' into u/mistral
LittleLittleCloud 68a212e
update
LittleLittleCloud 1bc69bd
set t to 0
LittleLittleCloud 05f726d
fix test
LittleLittleCloud 7eb9f58
Update Microsoft.ML.GenAI.Mistral.csproj
LittleLittleCloud File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
156 changes: 156 additions & 0 deletions
156
docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
using System.Text.Json; | ||
using AutoGen.Core; | ||
using Microsoft.ML.GenAI.Core; | ||
using Microsoft.ML.GenAI.Mistral; | ||
using Microsoft.ML.GenAI.Mistral.Module; | ||
using Microsoft.ML.Tokenizers; | ||
using TorchSharp; | ||
using TorchSharp.PyBridge; | ||
using static TorchSharp.torch; | ||
|
||
namespace Microsoft.ML.GenAI.Samples.Mistral; | ||
|
||
/// <summary>
/// Samples that load Mistral-family weights from a local huggingface-style weight folder and
/// run them through <c>CausalLMPipeline</c>: plain chat, tool calling, and embedding generation
/// with the bge-en-icl model.
/// </summary>
/// <remarks>
/// The default weight folders are the original author's local paths. Pass your own
/// <c>weightFolder</c> to run the samples on another machine.
/// </remarks>
public partial class Mistral_7B_Instruct
{
    private const string DefaultMistralWeightFolder = @"C:\Users\xiaoyuz\source\repos\Mistral-7B-Instruct-v0.3";
    private const string DefaultEmbeddingWeightFolder = @"C:\Users\xiaoyuz\source\repos\bge-en-icl";
    private const string DefaultDevice = "cuda";
    private const string ConfigName = "config.json";

    // Instance used to reach GetWeatherFunctionContract / GetWeatherWrapper, which are declared
    // in another part of this partial class (presumably generated from [Function] - confirm).
    private static readonly Mistral_7B_Instruct instance = new();

    /// <summary>
    /// get weather from city
    /// </summary>
    /// <param name="city">Name of the city to report the weather for.</param>
    /// <returns>A completed task holding a fixed "sunny" report that mentions <paramref name="city"/>.</returns>
    [Function]
    public Task<string> GetWeather(string city)
    {
        return Task.FromResult($"The weather in {city} is sunny.");
    }

    /// <summary>
    /// Loads Mistral-7B-Instruct and sends a single "How are you." message to the agent.
    /// </summary>
    /// <param name="weightFolder">Folder containing the model weights, tokenizer and config.json.</param>
    /// <param name="device">Device name handed to the pipeline; "cuda" additionally initializes the CUDA backend.</param>
    public static async Task RunAsync(
        string weightFolder = DefaultMistralWeightFolder,
        string device = DefaultDevice)
    {
        var pipeline = LoadChatPipeline(weightFolder, device);

        var agent = new MistralCausalLMAgent(pipeline, "assistant")
            .RegisterPrintMessage();

        var task = """
            How are you.
            """;

        await agent.SendAsync(task);
    }

    /// <summary>
    /// Loads the bge-en-icl embedding model, embeds a few-shot query and a candidate document,
    /// and prints their dot product scaled by 100 as a similarity score.
    /// </summary>
    /// <param name="weightFolder">Folder containing tokenizer.model, config.json and the safetensors checkpoint.</param>
    /// <param name="device">Device name the model is moved to; "cuda" additionally initializes the CUDA backend.</param>
    /// <exception cref="InvalidDataException">config.json deserializes to <c>null</c>.</exception>
    public static void Embedding(
        string weightFolder = DefaultEmbeddingWeightFolder,
        string device = DefaultDevice)
    {
        InitializeTorch(device, ScalarType.Float32);

        Console.WriteLine("Loading Mistral from huggingface model weight folder");
        var tokenizer = MistralTokenizerHelper.FromPretrained(weightFolder, modelName: "tokenizer.model");

        // A null result means the file content is unusable (e.g. literal JSON `null`),
        // so report the offending file rather than blaming an argument.
        var configPath = Path.Combine(weightFolder, ConfigName);
        var mistralConfig = JsonSerializer.Deserialize<MistralConfig>(File.ReadAllText(configPath))
            ?? throw new InvalidDataException($"'{configPath}' does not contain a Mistral model configuration.");
        var model = new MistralModel(mistralConfig);
        model.load_checkpoint(weightFolder, "model.safetensors.index.json", strict: true, useTqdm: false);
        model.to(device);

        var pipeline = new CausalLMPipeline<LlamaTokenizer, MistralModel>(tokenizer, model, device);

        // Two worked <instruct>/<query>/<response> examples followed by the real query,
        // whose <response> is left open.
        var query = """
            <instruct>Given a web search query, retrieve relevant passages that answer the query.
            <query>what is a virtual interface
            <response>A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes.

            <instruct>Given a web search query, retrieve relevant passages that answer the query.
            <query>causes of back pain in female for a week
            <response>Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper posture, spinal issues like herniated discs or osteoporosis, menstrual cramps causing referred pain, urinary tract infections, or pelvic inflammatory disease. Pregnancy-related changes can also contribute. Stress and lack of physical activity may exacerbate symptoms. Proper diagnosis by a healthcare professional is crucial for effective treatment and management.

            <instruct>Given a web search query, retrieve relevant passages that answer the query.
            <query>how much protein should a female eat
            <response>
            """;

        var document = """
            As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.
            """;
        var queryEmbedding = pipeline.GenerateEmbeddingFromLastTokenPool(query);
        var documentEmbedding = pipeline.GenerateEmbeddingFromLastTokenPool(document);

        // Dot product of the two embeddings, each term scaled by 100.
        var score = 0f;
        foreach (var (q, d) in queryEmbedding.Zip(documentEmbedding))
        {
            score += q * d * 100;
        }

        Console.WriteLine($"The similarity score between query and document is {score}");
    }

    /// <summary>
    /// Loads Mistral-7B-Instruct, registers <see cref="GetWeather"/> as a callable tool, asks for
    /// the weather in Seattle, and then sends the first reply back to the agent for a follow-up answer.
    /// </summary>
    /// <param name="weightFolder">Folder containing the model weights, tokenizer and config.json.</param>
    /// <param name="device">Device name handed to the pipeline; "cuda" additionally initializes the CUDA backend.</param>
    public static async Task WeatherChatAsync(
        string weightFolder = DefaultMistralWeightFolder,
        string device = DefaultDevice)
    {
        var pipeline = LoadChatPipeline(weightFolder, device);

        var weatherChatMiddleware = new FunctionCallMiddleware(
            functions: [instance.GetWeatherFunctionContract],
            functionMap: new Dictionary<string, Func<string, Task<string>>>
            {
                { instance.GetWeatherFunctionContract.Name!, instance.GetWeatherWrapper }
            });

        var agent = new MistralCausalLMAgent(pipeline, "assistant")
            .RegisterStreamingMiddleware(weatherChatMiddleware)
            .RegisterPrintMessage();

        var task = "what is the weather in Seattle";
        var userMessage = new TextMessage(Role.User, task);

        var reply = await agent.GenerateReplyAsync(messages: [userMessage],
            new GenerateReplyOptions
            {
                Temperature = 0f,
            });

        // generate further reply using tool call result;
        await agent.SendAsync(chatHistory: [userMessage, reply]);
    }

    /// <summary>
    /// Initializes the CUDA backend when <paramref name="device"/> is "cuda", then seeds torch
    /// with 1 and sets <paramref name="defaultType"/> as the default dtype.
    /// </summary>
    private static void InitializeTorch(string device, ScalarType defaultType)
    {
        if (device == "cuda")
        {
            torch.InitializeDeviceType(DeviceType.CUDA);
        }

        torch.manual_seed(1);
        torch.set_default_dtype(defaultType);
    }

    /// <summary>
    /// Builds the BFloat16 causal-LM pipeline (tokenizer + MistralForCausalLM) shared by the chat samples.
    /// </summary>
    private static CausalLMPipeline<LlamaTokenizer, MistralForCausalLM> LoadChatPipeline(string weightFolder, string device)
    {
        InitializeTorch(device, ScalarType.BFloat16);

        Console.WriteLine("Loading Mistral from huggingface model weight folder");
        var tokenizer = MistralTokenizerHelper.FromPretrained(weightFolder);
        var model = MistralForCausalLM.FromPretrained(weightFolder, ConfigName, layersOnTargetDevice: -1);

        return new CausalLMPipeline<LlamaTokenizer, MistralForCausalLM>(tokenizer, model, device);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
// See https://aka.ms/new-console-template for more information | ||
using Microsoft.ML.GenAI.Samples.Mistral; | ||
using Microsoft.ML.GenAI.Samples.Phi3Mini; | ||
|
||
// Entry point: the samples run one after another, each awaited to completion. | ||
// AutoGenSample is resolved through one of the sample namespaces imported above. | ||
await AutoGenSample.RunAsync(); | ||
// Mistral tool-call sample: asks for the weather in Seattle with GetWeather registered as a function. | ||
await Mistral_7B_Instruct.WeatherChatAsync(); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@luisquintanilla This is how the bge-en-icl embedding model will be consumed.