Skip to content

Commit

Permalink
Merge pull request #1224 from solliancenet/cp-vectorization-model-fix
Browse files Browse the repository at this point in the history
Send model name to Embedding call if present
  • Loading branch information
ciprianjichici authored Jul 11, 2024
2 parents 3f7dd76 + a6c98ad commit b42cd43
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/dotnet/Vectorization/Handlers/EmbeddingHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,17 @@ protected override async Task<bool> ProcessRequest(
{
var serviceFactory = _serviceProvider.GetService<IVectorizationServiceFactory<ITextEmbeddingService>>()
?? throw new VectorizationException($"Could not retrieve the text embedding service factory instance.");
var textEmbedding = serviceFactory.GetService(_parameters["text_embedding_profile_name"]);
var (textEmbeddingService, textEmbeddingProfileResourceBase) = serviceFactory.GetServiceWithResource(_parameters["text_embedding_profile_name"]);
var textEmbeddingProfile = textEmbeddingProfileResourceBase as TextEmbeddingProfile;
var embeddingModelName = textEmbeddingProfile!.Settings?.TryGetValue("model_name", out var modelName) == true ? modelName : null;

var embeddingResult = default(TextEmbeddingResult);

if (request.RunningOperations.TryGetValue(_stepId, out var runningOperation))
{
// We have an ongoing operation, so we need to attempt to retrieve the emebdding results

embeddingResult = await textEmbedding.GetEmbeddingsAsync(runningOperation.OperationId);
embeddingResult = await textEmbeddingService.GetEmbeddingsAsync(runningOperation.OperationId);

runningOperation.LastResponseTime = DateTime.UtcNow;
runningOperation.PollingCount++;
Expand Down Expand Up @@ -87,13 +89,14 @@ protected override async Task<bool> ProcessRequest(
return false;
}

embeddingResult = await textEmbedding.GetEmbeddingsAsync(
embeddingResult = await textEmbeddingService.GetEmbeddingsAsync(
textPartitioningArtifacts.Select(tpa => new TextChunk
{
Position = tpa.Position,
Content = tpa.Content!,
TokensCount = tpa.Size
}).ToList());
}).ToList(),
embeddingModelName);

if (embeddingResult.InProgress)
{
Expand Down

0 comments on commit b42cd43

Please sign in to comment.