[GenAI] Add generateEmbedding API to CausalLMPipeline (#7227)
* add embedding

* add FromPretrained API to Phi3 model

* fix bug

* Update CausalLMPipeline.cs
LittleLittleCloud authored Aug 30, 2024
1 parent 1d1cc99 commit 7c937bf
Showing 6 changed files with 93 additions and 111 deletions.
@@ -9,6 +9,7 @@
 using TorchSharp;
 using Microsoft.ML.GenAI.Core;
 using Microsoft.ML.GenAI.Core.Extension;
+using Microsoft.ML.Tokenizers;

 namespace Microsoft.ML.GenAI.Samples.Phi3Mini;

@@ -26,12 +27,15 @@ public static async Task RunAsync()
     torch.manual_seed(1);
     torch.set_default_dtype(defaultType);
     var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
-    var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device, quantizeToInt8: false);
+    var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
+    var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
+    var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
+    var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
+    var question = @"write a C# program to calculate the factorial of a number";

     // agent
     var agent = new Phi3Agent(pipeline, "assistant")
         .RegisterPrintMessage();
-    var question = @"write a C# program to calculate the factorial of a number";

     // chat with the assistant
     await agent.SendAsync(question);
@@ -1,4 +1,7 @@
-using Microsoft.ML.GenAI.Phi.Extension;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Phi;
+using Microsoft.ML.GenAI.Phi.Extension;
+using Microsoft.ML.Tokenizers;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;
 using TorchSharp;

@@ -20,8 +23,10 @@ public static async Task RunChatCompletionSample()
     torch.manual_seed(1);
     torch.set_default_dtype(defaultType);
     var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
-    var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device);
+    var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
+    var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
+    var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
+    var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);

     var kernel = Kernel.CreateBuilder()
         .AddGenAIChatCompletion(pipeline)

@@ -49,8 +54,10 @@ public static async Task RunTextGenerationSample()
     torch.manual_seed(1);
     torch.set_default_dtype(defaultType);
     var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
-    var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device);
+    var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
+    var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
+    var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
+    var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);

     var kernel = Kernel.CreateBuilder()
         .AddGenAITextGeneration(pipeline)
103 changes: 0 additions & 103 deletions docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs

This file was deleted.
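Note: the deleted helper bundled the steps now inlined at each call site in the sample diffs above (load the tokenizer, load and optionally quantize the model, construct the pipeline); its quantization and device-placement logic is superseded by the new Phi3ForCasualLM.FromPretrained overload added below.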

2 changes: 1 addition & 1 deletion docs/samples/Microsoft.ML.GenAI.Samples/Program.cs
@@ -1,4 +1,4 @@
 // See https://aka.ms/new-console-template for more information
 using Microsoft.ML.GenAI.Samples.Phi3Mini;

-await SemanticKernelSample.RunChatCompletionSample();
+await AutoGenSample.RunAsync();
24 changes: 24 additions & 0 deletions src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs
@@ -32,6 +32,11 @@ string Generate(
         float topP = CausalLMPipeline.Defaults.TopP,
         string[]? stopSequences = CausalLMPipeline.Defaults.StopSequence);

+    /// <summary>
+    /// Generate the embedding (the last hidden state of the last token) for the prompt. The embedding is normalized by its L2 norm.
+    /// </summary>
+    float[] GenerateEmbeddingFromLastTokenPool(string prompt);
+
     IEnumerable<string> GenerateStreaming(
         string prompt,
         int maxLen = CausalLMPipeline.Defaults.MaxLen,
@@ -281,4 +286,23 @@ protected torch.Tensor SampleTopP(torch.Tensor logits, float topP)
         nextToken = torch.gather(probsIndex, dim: -1, index: nextToken);
         return nextToken;
     }
+
+    public float[] GenerateEmbeddingFromLastTokenPool(string prompt)
+    {
+        using var scope = NewDisposeScope();
+        using var noGrad = torch.no_grad();
+        var inputIds = this.Tokenizer.EncodeToIds(prompt);
+        var inputTensor = torch.tensor(inputIds.ToArray(), dtype: ScalarType.Int64, device: this.Device).unsqueeze(0);
+        var attentionMask = torch.ones_like(inputTensor, device: this.Device);
+        var input = new CausalLMModelInput(inputTensor, attentionMask, pastKeyValuesLength: 0);
+        var output = this.Model.forward(input);
+        var lastTokenHiddenState = output.LastHiddenState[0, ^1];
+
+        // shape of lastTokenHiddenState: [hidden_size]
+        // L2 norm
+        var norm = lastTokenHiddenState.norm();
+        var normalized = lastTokenHiddenState / norm;
+
+        return normalized.to_type(ScalarType.Float32).data<float>().ToArray();
+    }
 }
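For illustration, a minimal usage sketch of the new GenerateEmbeddingFromLastTokenPool API, assuming a pipeline built as in the samples above; the model folder path and prompts are placeholders. Because the returned vector is L2-normalized, the cosine similarity of two prompts is simply their dot product.

using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Phi;
using Microsoft.ML.Tokenizers;

// Placeholder path; point this at a local Phi-3 checkpoint folder.
var weightFolder = @"C:\models\Phi-3-mini-4k-instruct";
var tokenizer = Phi3TokenizerHelper.FromPretrained(Path.Combine(weightFolder, "tokenizer.model"));
var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, "cuda");

// Each embedding is unit length, so the dot product equals the cosine similarity.
float[] a = pipeline.GenerateEmbeddingFromLastTokenPool("The quick brown fox");
float[] b = pipeline.GenerateEmbeddingFromLastTokenPool("A fast auburn fox");
var cosine = 0f;
for (var i = 0; i < a.Length; i++)
{
    cosine += a[i] * b[i];
}
Console.WriteLine($"cosine similarity: {cosine}");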
50 changes: 50 additions & 0 deletions src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs
@@ -9,6 +9,7 @@
 using System.Text.Json;
 using System.Threading.Tasks;
 using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Core.Extension;
 using Microsoft.ML.GenAI.Phi.Module;
 using TorchSharp;
 using TorchSharp.Modules;
@@ -66,6 +67,55 @@ public static Phi3ForCasualLM FromPretrained(
         return phi;
     }

+    public static Phi3ForCasualLM FromPretrained(
+        string modelFolder,
+        string configName = "config.json",
+        string checkPointName = "model.safetensors.index.json",
+        bool quantizeToInt8 = false,
+        bool quantizeToInt4 = false,
+        int layersOnTargetDevice = -1,
+        ScalarType torchDtype = ScalarType.BFloat16,
+        string targetDevice = "cuda")
+    {
+        if (layersOnTargetDevice == -1 && quantizeToInt4 == false && quantizeToInt8 == false)
+        {
+            return FromPretrained(modelFolder, configName, checkPointName, torchDtype, targetDevice);
+        }
+
+        var originalDefaultDevice = torch.get_default_device();
+        torch.set_default_device("meta");
+        var config = Path.Join(modelFolder, configName);
+        var modelConfig = JsonSerializer.Deserialize<Phi3Config>(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config));
+        modelConfig.DType = torchDtype;
+        var model = new Phi3ForCasualLM(modelConfig);
+
+        if (quantizeToInt8)
+        {
+            model.ToInt8QuantizeModule();
+        }
+        else if (quantizeToInt4)
+        {
+            model.ToInt4QuantizeModule();
+        }
+
+        var deviceMap = model.InferDeviceMapForEachLayer(
+            [
+                KeyValuePair.Create(targetDevice, layersOnTargetDevice),
+                KeyValuePair.Create("cpu", -1)
+            ]);
+
+        torch.set_default_device("cpu");
+        model = new Phi3ForCasualLM(modelConfig);
+
+        model.LoadSafeTensors(modelFolder, checkPointName);
+
+        model = model.ToDynamicLoadingModel(deviceMap, targetDevice);
+
+        torch.set_default_device(originalDefaultDevice);
+
+        return model;
+    }
+
     public void LoadSafeTensors(string modelFolder, string checkPointName = "model.safetensors.index.json")
     {
         this.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, useTqdm: false);
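For illustration, a hedged sketch of calling the new FromPretrained overload with a partial device placement: a positive layersOnTargetDevice pins that many layers on the target device and leaves the remainder on CPU (via the device map inferred above), while -1 places all layers on the target device as the samples do. The folder path and layer count below are illustrative, not prescribed values.

// Illustrative call: int4 weight quantization with 16 layers on the GPU.
var model = Phi3ForCasualLM.FromPretrained(
    @"C:\models\Phi-3-mini-4k-instruct", // placeholder weight folder
    configName: "config.json",
    quantizeToInt4: true,                // or quantizeToInt8: true
    layersOnTargetDevice: 16,            // remaining layers stay on CPU
    targetDevice: "cuda");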
