Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GenAI] Add generateEmbedding API to CausalLMPipeline #7227

Merged
merged 4 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using TorchSharp;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Core.Extension;
using Microsoft.ML.Tokenizers;

namespace Microsoft.ML.GenAI.Samples.Phi3Mini;

Expand All @@ -26,12 +27,15 @@ public static async Task RunAsync()
torch.manual_seed(1);
torch.set_default_dtype(defaultType);
var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device, quantizeToInt8: false);
var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
var question = @"write a C# program to calculate the factorial of a number";

// agent
var agent = new Phi3Agent(pipeline, "assistant")
.RegisterPrintMessage();
var question = @"write a C# program to calculate the factorial of a number";

// chat with the assistant
await agent.SendAsync(question);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
using Microsoft.ML.GenAI.Phi.Extension;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Phi;
using Microsoft.ML.GenAI.Phi.Extension;
using Microsoft.ML.Tokenizers;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using TorchSharp;
Expand All @@ -20,8 +23,10 @@ public static async Task RunChatCompletionSample()
torch.manual_seed(1);
torch.set_default_dtype(defaultType);
var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device);

var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);

var kernel = Kernel.CreateBuilder()
.AddGenAIChatCompletion(pipeline)
Expand Down Expand Up @@ -49,8 +54,10 @@ public static async Task RunTextGenerationSample()
torch.manual_seed(1);
torch.set_default_dtype(defaultType);
var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device);

var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);

var kernel = Kernel.CreateBuilder()
.AddGenAITextGeneration(pipeline)
Expand Down
103 changes: 0 additions & 103 deletions docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs

This file was deleted.

2 changes: 1 addition & 1 deletion docs/samples/Microsoft.ML.GenAI.Samples/Program.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// See https://aka.ms/new-console-template for more information
using Microsoft.ML.GenAI.Samples.Phi3Mini;

await SemanticKernelSample.RunChatCompletionSample();
await AutoGenSample.RunAsync();
24 changes: 24 additions & 0 deletions src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ string Generate(
float topP = CausalLMPipeline.Defaults.TopP,
string[]? stopSequences = CausalLMPipeline.Defaults.StopSequence);

/// <summary>
/// Generate the embedding (the last hidden state of the last token) for the prompt. The embedding is L2-normalized.
/// </summary>
float[] GenerateEmbeddingFromLastTokenPool(string prompt);

IEnumerable<string> GenerateStreaming(
string prompt,
int maxLen = CausalLMPipeline.Defaults.MaxLen,
Expand Down Expand Up @@ -281,4 +286,23 @@ protected torch.Tensor SampleTopP(torch.Tensor logits, float topP)
nextToken = torch.gather(probsIndex, dim: -1, index: nextToken);
return nextToken;
}

public float[] GenerateEmbeddingFromLastTokenPool(string prompt)
{
    // Run inference without gradient tracking and dispose all intermediate tensors on exit.
    using var scope = NewDisposeScope();
    using var noGrad = torch.no_grad();

    // Tokenize the prompt and shape it as a [1, seq_len] batch on the pipeline's device.
    var tokenIds = this.Tokenizer.EncodeToIds(prompt).ToArray();
    var ids = torch.tensor(tokenIds, dtype: ScalarType.Int64, device: this.Device).unsqueeze(0);
    var mask = torch.ones_like(ids, device: this.Device);

    var modelOutput = this.Model.forward(new CausalLMModelInput(ids, mask, pastKeyValuesLength: 0));

    // Last-token pooling: hidden state of the final token of the (single) batch item.
    // shape: [hidden_size]
    var pooled = modelOutput.LastHiddenState[0, ^1];

    // L2-normalize so downstream cosine similarity reduces to a dot product.
    var unitVector = pooled / pooled.norm();

    return unitVector.to_type(ScalarType.Float32).data<float>().ToArray();
}
}
50 changes: 50 additions & 0 deletions src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Core.Extension;
using Microsoft.ML.GenAI.Phi.Module;
using TorchSharp;
using TorchSharp.Modules;
Expand Down Expand Up @@ -66,6 +67,55 @@ public static Phi3ForCasualLM FromPretrained(
return phi;
}

/// <summary>
/// Load a <see cref="Phi3ForCasualLM"/> from <paramref name="modelFolder"/>, optionally quantizing it
/// and splitting its layers between <paramref name="targetDevice"/> and CPU.
/// </summary>
/// <param name="modelFolder">Folder containing the config and safetensors checkpoint.</param>
/// <param name="configName">Config file name within <paramref name="modelFolder"/>.</param>
/// <param name="checkPointName">Safetensors index file name within <paramref name="modelFolder"/>.</param>
/// <param name="quantizeToInt8">Quantize the model to int8. Mutually exclusive with <paramref name="quantizeToInt4"/>; int8 wins if both are set.</param>
/// <param name="quantizeToInt4">Quantize the model to int4.</param>
/// <param name="layersOnTargetDevice">Number of layers to place on <paramref name="targetDevice"/>; -1 means all layers.</param>
/// <param name="torchDtype">Tensor dtype for the model weights.</param>
/// <param name="targetDevice">Device that hosts <paramref name="layersOnTargetDevice"/> layers; remaining layers stay on CPU.</param>
public static Phi3ForCasualLM FromPretrained(
    string modelFolder,
    string configName = "config.json",
    string checkPointName = "model.safetensors.index.json",
    bool quantizeToInt8 = false,
    bool quantizeToInt4 = false,
    int layersOnTargetDevice = -1,
    ScalarType torchDtype = ScalarType.BFloat16,
    string targetDevice = "cuda")
{
    // Fast path: no partial device placement and no quantization requested.
    if (layersOnTargetDevice == -1 && quantizeToInt4 == false && quantizeToInt8 == false)
    {
        return FromPretrained(modelFolder, configName, checkPointName, torchDtype, targetDevice);
    }

    var originalDefaultDevice = torch.get_default_device();
    try
    {
        // Build a weight-free "meta" model first so the per-layer footprint (including the
        // effect of quantization) can be inferred without allocating real tensors.
        torch.set_default_device("meta");
        var config = Path.Join(modelFolder, configName);
        var modelConfig = JsonSerializer.Deserialize<Phi3Config>(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config));
        modelConfig.DType = torchDtype;
        var model = new Phi3ForCasualLM(modelConfig);

        if (quantizeToInt8)
        {
            model.ToInt8QuantizeModule();
        }
        else if (quantizeToInt4)
        {
            model.ToInt4QuantizeModule();
        }

        // Decide how many layers live on the target device and which spill to CPU.
        var deviceMap = model.InferDeviceMapForEachLayer(
        [
            KeyValuePair.Create(targetDevice, layersOnTargetDevice),
            KeyValuePair.Create("cpu", -1)
        ]);

        // Rebuild the model on CPU and load the real weights.
        // NOTE(review): quantization is applied only to the discarded meta model used for
        // sizing; the reloaded CPU model is not re-quantized here — confirm that
        // ToDynamicLoadingModel handles this, or re-apply ToInt8/Int4QuantizeModule after reload.
        torch.set_default_device("cpu");
        model = new Phi3ForCasualLM(modelConfig);

        model.LoadSafeTensors(modelFolder, checkPointName);

        model = model.ToDynamicLoadingModel(deviceMap, targetDevice);

        return model;
    }
    finally
    {
        // Restore the caller's default device even when config parsing or weight loading throws;
        // the original code leaked the "meta"/"cpu" default device on any exception.
        torch.set_default_device(originalDefaultDevice);
    }
}

public void LoadSafeTensors(string modelFolder, string checkPointName = "model.safetensors.index.json")
{
this.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, useTqdm: false);
Expand Down
Loading