-
Notifications
You must be signed in to change notification settings - Fork 64
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds the OpenAIMockResponsePlugin (#768)
- Loading branch information
1 parent
9e73c68
commit fa465d4
Showing
15 changed files
with
757 additions
and
170 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
10 changes: 10 additions & 0 deletions
10
dev-proxy-abstractions/LanguageModel/ILanguageModelChatCompletionMessage.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT License. | ||
|
||
namespace Microsoft.DevProxy.Abstractions; | ||
|
||
/// <summary>
/// A single message in a chat-completion exchange with a language model.
/// </summary>
public interface ILanguageModelChatCompletionMessage
{
    /// <summary>Author role associated with the message.</summary>
    string Role { get; set; }

    /// <summary>Text content of the message.</summary>
    string Content { get; set; }
}
11 changes: 11 additions & 0 deletions
11
dev-proxy-abstractions/LanguageModel/ILanguageModelClient.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT License. | ||
|
||
namespace Microsoft.DevProxy.Abstractions; | ||
|
||
/// <summary>
/// Client for a language model: availability check, prompt completions,
/// and chat completions.
/// </summary>
public interface ILanguageModelClient
{
    /// <summary>Checks whether the language model is configured and reachable.</summary>
    Task<bool> IsEnabled();

    /// <summary>Generates a completion for a single prompt; null when the model is unavailable.</summary>
    Task<ILanguageModelCompletionResponse?> GenerateCompletion(string prompt);

    /// <summary>Generates a chat completion for a message history; null when the model is unavailable.</summary>
    Task<ILanguageModelCompletionResponse?> GenerateChatCompletion(ILanguageModelChatCompletionMessage[] messages);
}
10 changes: 10 additions & 0 deletions
10
dev-proxy-abstractions/LanguageModel/ILanguageModelCompletionResponse.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT License. | ||
|
||
namespace Microsoft.DevProxy.Abstractions; | ||
|
||
/// <summary>
/// Result of a language-model completion request: either a response text,
/// an error, or (on transport failure) neither.
/// </summary>
public interface ILanguageModelCompletionResponse
{
    /// <summary>Error reported by the model API; null on success.</summary>
    string? Error { get; set; }

    /// <summary>Generated text; null when no response was produced.</summary>
    string? Response { get; set; }
}
2 changes: 1 addition & 1 deletion
2
...nguageModel/LanguageModelConfiguration.cs → ...nguageModel/LanguageModelConfiguration.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
268 changes: 268 additions & 0 deletions
268
dev-proxy-abstractions/LanguageModel/OllamaLanguageModelClient.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,268 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT License. | ||
|
||
using System.Diagnostics; | ||
using System.Net.Http.Json; | ||
using Microsoft.Extensions.Logging; | ||
|
||
namespace Microsoft.DevProxy.Abstractions; | ||
|
||
/// <summary>
/// <see cref="ILanguageModelClient"/> implementation backed by a local Ollama
/// instance (<c>/api/generate</c> and <c>/api/chat</c> endpoints).
/// Availability is probed once and cached; responses can optionally be cached
/// per prompt / per message history.
/// </summary>
public class OllamaLanguageModelClient(LanguageModelConfiguration? configuration, ILogger logger) : ILanguageModelClient
{
    // One shared HttpClient for the process: creating a client per request
    // exhausts sockets under load (see .NET HttpClient guidelines).
    private static readonly HttpClient _httpClient = new();

    private readonly LanguageModelConfiguration? _configuration = configuration;
    private readonly ILogger _logger = logger;
    private bool? _lmAvailable;
    private readonly Dictionary<string, OllamaLanguageModelCompletionResponse> _cacheCompletion = new();
    // FIX: arrays use reference equality by default, so without a structural
    // comparer a fresh-but-equal message array never produced a cache hit.
    private readonly Dictionary<ILanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> _cacheChatCompletion = new(MessagesComparer.Instance);

    /// <summary>
    /// Returns whether the language model is usable. The model is probed on the
    /// first call and the result is cached for subsequent calls.
    /// </summary>
    public async Task<bool> IsEnabled()
    {
        if (_lmAvailable.HasValue)
        {
            return _lmAvailable.Value;
        }

        _lmAvailable = await IsEnabledInternal();
        return _lmAvailable.Value;
    }

    // Validates configuration, then probes the configured URL and issues a test
    // completion. Any failure disables the language model.
    private async Task<bool> IsEnabledInternal()
    {
        if (_configuration is null || !_configuration.Enabled)
        {
            return false;
        }

        if (string.IsNullOrEmpty(_configuration.Url))
        {
            _logger.LogError("URL is not set. Language model will be disabled");
            return false;
        }

        if (string.IsNullOrEmpty(_configuration.Model))
        {
            _logger.LogError("Model is not set. Language model will be disabled");
            return false;
        }

        _logger.LogDebug("Checking LM availability at {url}...", _configuration.Url);

        try
        {
            // check if lm is on
            var response = await _httpClient.GetAsync(_configuration.Url);
            _logger.LogDebug("Response: {response}", response.StatusCode);

            if (!response.IsSuccessStatusCode)
            {
                return false;
            }

            var testCompletion = await GenerateCompletionInternal("Are you there? Reply with a yes or no.");
            if (testCompletion?.Error is not null)
            {
                _logger.LogError("Error: {error}", testCompletion.Error);
                return false;
            }

            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Couldn't reach language model at {url}", _configuration.Url);
            return false;
        }
    }

    /// <summary>
    /// Generates a completion for <paramref name="prompt"/>. Returns null when
    /// the model is unavailable, availability has not been checked via
    /// <see cref="IsEnabled"/>, or the API reported an error.
    /// </summary>
    public async Task<ILanguageModelCompletionResponse?> GenerateCompletion(string prompt)
    {
        using var scope = _logger.BeginScope(nameof(OllamaLanguageModelClient));

        if (_configuration is null)
        {
            return null;
        }

        if (!_lmAvailable.HasValue)
        {
            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabled));
            return null;
        }

        if (!_lmAvailable.Value)
        {
            return null;
        }

        if (_configuration.CacheResponses && _cacheCompletion.TryGetValue(prompt, out var cachedResponse))
        {
            _logger.LogDebug("Returning cached response for prompt: {prompt}", prompt);
            return cachedResponse;
        }

        var response = await GenerateCompletionInternal(prompt);
        if (response is null)
        {
            return null;
        }
        if (response.Error is not null)
        {
            _logger.LogError(response.Error);
            return null;
        }

        if (_configuration.CacheResponses && response.Response is not null)
        {
            _cacheCompletion[prompt] = response;
        }

        return response;
    }

    // POSTs to Ollama's /api/generate endpoint; returns null on transport failure.
    private async Task<OllamaLanguageModelCompletionResponse?> GenerateCompletionInternal(string prompt)
    {
        Debug.Assert(_configuration != null, "Configuration is null");

        try
        {
            var url = $"{_configuration.Url}/api/generate";
            _logger.LogDebug("Requesting completion. Prompt: {prompt}", prompt);

            var response = await _httpClient.PostAsJsonAsync(url,
                new
                {
                    prompt,
                    model = _configuration.Model,
                    stream = false
                }
            );
            _logger.LogDebug("Response: {response}", response.StatusCode);

            var res = await response.Content.ReadFromJsonAsync<OllamaLanguageModelCompletionResponse>();
            if (res is null)
            {
                return res;
            }

            // recorded so the mock output can log where the response came from
            res.RequestUrl = url;
            return res;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to generate completion");
            return null;
        }
    }

    /// <summary>
    /// Generates a chat completion for <paramref name="messages"/>. Returns null
    /// when the model is unavailable, availability has not been checked via
    /// <see cref="IsEnabled"/>, or the API reported an error.
    /// </summary>
    public async Task<ILanguageModelCompletionResponse?> GenerateChatCompletion(ILanguageModelChatCompletionMessage[] messages)
    {
        using var scope = _logger.BeginScope(nameof(OllamaLanguageModelClient));

        if (_configuration is null)
        {
            return null;
        }

        if (!_lmAvailable.HasValue)
        {
            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabled));
            return null;
        }

        if (!_lmAvailable.Value)
        {
            return null;
        }

        if (_configuration.CacheResponses && _cacheChatCompletion.TryGetValue(messages, out var cachedResponse))
        {
            _logger.LogDebug("Returning cached response for message: {lastMessage}", messages.Last().Content);
            return cachedResponse;
        }

        var response = await GenerateChatCompletionInternal(messages);
        if (response is null)
        {
            return null;
        }
        if (response.Error is not null)
        {
            _logger.LogError(response.Error);
            return null;
        }

        if (_configuration.CacheResponses && response.Response is not null)
        {
            _cacheChatCompletion[messages] = response;
        }

        return response;
    }

    // POSTs to Ollama's /api/chat endpoint; returns null on transport failure.
    private async Task<OllamaLanguageModelChatCompletionResponse?> GenerateChatCompletionInternal(ILanguageModelChatCompletionMessage[] messages)
    {
        Debug.Assert(_configuration != null, "Configuration is null");

        try
        {
            var url = $"{_configuration.Url}/api/chat";
            _logger.LogDebug("Requesting chat completion. Message: {lastMessage}", messages.Last().Content);

            var response = await _httpClient.PostAsJsonAsync(url,
                new
                {
                    messages,
                    model = _configuration.Model,
                    stream = false
                }
            );
            _logger.LogDebug("Response: {response}", response.StatusCode);

            var res = await response.Content.ReadFromJsonAsync<OllamaLanguageModelChatCompletionResponse>();
            if (res is null)
            {
                return res;
            }

            // recorded so the mock output can log where the response came from
            res.RequestUrl = url;
            return res;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to generate chat completion");
            return null;
        }
    }

    // Element-wise equality for message arrays (by Content and Role), so cached
    // chat completions are found when callers pass a new array with equal content.
    private sealed class MessagesComparer : IEqualityComparer<ILanguageModelChatCompletionMessage[]>
    {
        public static readonly MessagesComparer Instance = new();

        public bool Equals(ILanguageModelChatCompletionMessage[]? x, ILanguageModelChatCompletionMessage[]? y)
        {
            if (ReferenceEquals(x, y))
            {
                return true;
            }
            if (x is null || y is null || x.Length != y.Length)
            {
                return false;
            }

            for (var i = 0; i < x.Length; i++)
            {
                if (x[i].Content != y[i].Content || x[i].Role != y[i].Role)
                {
                    return false;
                }
            }

            return true;
        }

        public int GetHashCode(ILanguageModelChatCompletionMessage[] obj)
        {
            var hash = new HashCode();
            foreach (var message in obj)
            {
                hash.Add(message.Content);
                hash.Add(message.Role);
            }
            return hash.ToHashCode();
        }
    }
}
|
||
/// <summary>
/// Structural-lookup helpers for chat-completion caches keyed by message arrays,
/// comparing messages element-wise rather than by array reference.
/// </summary>
internal static class CacheChatCompletionExtensions
{
    /// <summary>
    /// Finds the stored key whose messages are element-wise equal to
    /// <paramref name="messages"/>; null when there is no match.
    /// </summary>
    public static OllamaLanguageModelChatCompletionMessage[]? GetKey(
        this Dictionary<OllamaLanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> cache,
        ILanguageModelChatCompletionMessage[] messages)
    {
        return cache.Keys.FirstOrDefault(k => k.SequenceEqual(messages));
    }

    /// <summary>
    /// Structural TryGetValue. Scans the entries once instead of locating the key
    /// (O(n)) and then performing a second dictionary lookup as before.
    /// </summary>
    public static bool TryGetValue(
        this Dictionary<OllamaLanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> cache,
        ILanguageModelChatCompletionMessage[] messages, out OllamaLanguageModelChatCompletionResponse? value)
    {
        foreach (var (key, cached) in cache)
        {
            if (key.SequenceEqual(messages))
            {
                value = cached;
                return true;
            }
        }

        value = null;
        return false;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT License. | ||
|
||
using System.Text.Json.Serialization; | ||
|
||
namespace Microsoft.DevProxy.Abstractions; | ||
|
||
/// <summary>
/// Common shape of a response returned by the Ollama API. Snake_case JSON fields
/// are mapped via <see cref="JsonPropertyNameAttribute"/>; the remaining
/// properties bind by name (deserialization here uses web defaults, which are
/// case-insensitive).
/// </summary>
public abstract class OllamaResponse : ILanguageModelCompletionResponse
{
    [JsonPropertyName("created_at")]
    public DateTime CreatedAt { get; set; } = DateTime.MinValue;
    public bool Done { get; set; } = false;
    // Error reported by the API; null on success.
    public string? Error { get; set; }
    // Token-count and duration metrics as reported by Ollama.
    [JsonPropertyName("eval_count")]
    public long EvalCount { get; set; }
    [JsonPropertyName("eval_duration")]
    public long EvalDuration { get; set; }
    [JsonPropertyName("load_duration")]
    public long LoadDuration { get; set; }
    public string Model { get; set; } = string.Empty;
    [JsonPropertyName("prompt_eval_count")]
    public long PromptEvalCount { get; set; }
    [JsonPropertyName("prompt_eval_duration")]
    public long PromptEvalDuration { get; set; }
    // Virtual so chat responses can redirect it to their Message content.
    public virtual string? Response { get; set; }
    [JsonPropertyName("total_duration")]
    public long TotalDuration { get; set; }
    // custom property added to log in the mock output; not part of the API payload
    public string RequestUrl { get; set; } = string.Empty;
}
|
||
/// <summary>
/// Response of Ollama's generate endpoint; adds the conversation context tokens
/// to the common response shape.
/// </summary>
public class OllamaLanguageModelCompletionResponse : OllamaResponse
{
    /// <summary>Context tokens returned by the API; empty by default.</summary>
    public int[] Context { get; set; } = Array.Empty<int>();
}
|
||
/// <summary>
/// Response of Ollama's chat endpoint. The generated text lives in
/// <see cref="Message"/>; <see cref="Response"/> mirrors its content so chat
/// responses satisfy the common completion-response contract.
/// </summary>
public class OllamaLanguageModelChatCompletionResponse : OllamaResponse
{
    /// <summary>Message returned by the chat endpoint.</summary>
    public OllamaLanguageModelChatCompletionMessage Message { get; set; } = new();

    /// <summary>
    /// Reads through to <see cref="Message"/> content; assigning null is a no-op,
    /// any other value replaces the message.
    /// </summary>
    public override string? Response
    {
        get => Message.Content;
        set
        {
            if (value is not null)
            {
                Message = new() { Content = value };
            }
        }
    }
}
|
||
/// <summary>
/// Chat message exchanged with Ollama. Implements value equality over
/// <see cref="Content"/> and <see cref="Role"/> (exact runtime type required),
/// which enables structural comparison of cached message arrays.
/// </summary>
public class OllamaLanguageModelChatCompletionMessage : ILanguageModelChatCompletionMessage
{
    public string Content { get; set; } = string.Empty;
    public string Role { get; set; } = string.Empty;

    /// <summary>Equal when the other object has the exact same type, content, and role.</summary>
    public override bool Equals(object? obj) =>
        obj is OllamaLanguageModelChatCompletionMessage other
        && other.GetType() == GetType()
        && Content == other.Content
        && Role == other.Role;

    /// <summary>Hash consistent with <see cref="Equals(object?)"/>.</summary>
    public override int GetHashCode() => HashCode.Combine(Content, Role);
}
Oops, something went wrong.