Skip to content

Commit

Permalink
Adds the OpenAIMockResponsePlugin (#768)
Browse files Browse the repository at this point in the history
  • Loading branch information
waldekmastykarz authored Jun 6, 2024
1 parent 9e73c68 commit fa465d4
Show file tree
Hide file tree
Showing 15 changed files with 757 additions and 170 deletions.
4 changes: 0 additions & 4 deletions dev-proxy-abstractions/ILanguageModelClient.cs

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace Microsoft.DevProxy.Abstractions;

/// <summary>
/// A single message in a chat-completion conversation sent to a language model.
/// </summary>
public interface ILanguageModelChatCompletionMessage
{
    /// <summary>The text body of the message.</summary>
    string Content { get; set; }
    /// <summary>The author of the message (e.g. a role name understood by the model API).</summary>
    string Role { get; set; }
}
11 changes: 11 additions & 0 deletions dev-proxy-abstractions/LanguageModel/ILanguageModelClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace Microsoft.DevProxy.Abstractions;

/// <summary>
/// Client for a local or remote language model used by Dev Proxy.
/// </summary>
public interface ILanguageModelClient
{
    /// <summary>
    /// Requests a chat completion for the given conversation messages.
    /// Returns null when the language model is unavailable or the request fails.
    /// </summary>
    Task<ILanguageModelCompletionResponse?> GenerateChatCompletion(ILanguageModelChatCompletionMessage[] messages);
    /// <summary>
    /// Requests a completion for a single prompt.
    /// Returns null when the language model is unavailable or the request fails.
    /// </summary>
    Task<ILanguageModelCompletionResponse?> GenerateCompletion(string prompt);
    /// <summary>
    /// Checks whether the language model is configured and reachable.
    /// Implementations should be called before the Generate* methods.
    /// </summary>
    Task<bool> IsEnabled();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace Microsoft.DevProxy.Abstractions;

/// <summary>
/// Result of a language model completion request.
/// </summary>
public interface ILanguageModelCompletionResponse
{
    /// <summary>Error reported by the model API; null on success.</summary>
    string? Error { get; set; }
    /// <summary>The generated text; null when no response was produced.</summary>
    string? Response { get; set; }
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace Microsoft.DevProxy.LanguageModel;
namespace Microsoft.DevProxy.Abstractions;

public class LanguageModelConfiguration
{
Expand Down
268 changes: 268 additions & 0 deletions dev-proxy-abstractions/LanguageModel/OllamaLanguageModelClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Diagnostics;
using System.Net.Http.Json;
using Microsoft.Extensions.Logging;

namespace Microsoft.DevProxy.Abstractions;

/// <summary>
/// <see cref="ILanguageModelClient"/> implementation backed by a local Ollama
/// server (talks to its <c>/api/generate</c> and <c>/api/chat</c> endpoints).
/// Optionally caches responses when <c>LanguageModelConfiguration.CacheResponses</c> is set.
/// </summary>
public class OllamaLanguageModelClient(LanguageModelConfiguration? configuration, ILogger logger) : ILanguageModelClient
{
    // Reuse one HttpClient for all requests; allocating a new client per call
    // exhausts sockets under load (standard HttpClient lifetime guidance).
    private static readonly HttpClient _httpClient = new();

    private readonly LanguageModelConfiguration? _configuration = configuration;
    private readonly ILogger _logger = logger;
    // null until the first IsEnabled() call; caches the availability probe result.
    private bool? _lmAvailable;
    private readonly Dictionary<string, OllamaLanguageModelCompletionResponse> _cacheCompletion = new();
    // Arrays compare by reference by default, so a plain Dictionary would never
    // produce a cache hit for logically-equal message sequences. The comparer
    // restores element-wise (value) semantics for the keys.
    private readonly Dictionary<ILanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> _cacheChatCompletion = new(MessageSequenceComparer.Instance);

    /// <summary>
    /// Determines whether the language model is configured and reachable.
    /// The result of the first probe is cached for the lifetime of this client.
    /// </summary>
    public async Task<bool> IsEnabled()
    {
        _lmAvailable ??= await IsEnabledInternal();
        return _lmAvailable.Value;
    }

    // Validates configuration and probes the remote endpoint; any failure
    // disables the language model.
    private async Task<bool> IsEnabledInternal()
    {
        if (_configuration is null || !_configuration.Enabled)
        {
            return false;
        }

        if (string.IsNullOrEmpty(_configuration.Url))
        {
            _logger.LogError("URL is not set. Language model will be disabled");
            return false;
        }

        if (string.IsNullOrEmpty(_configuration.Model))
        {
            _logger.LogError("Model is not set. Language model will be disabled");
            return false;
        }

        _logger.LogDebug("Checking LM availability at {url}...", _configuration.Url);

        try
        {
            // check if the Ollama server responds at all
            var response = await _httpClient.GetAsync(_configuration.Url);
            _logger.LogDebug("Response: {response}", response.StatusCode);

            if (!response.IsSuccessStatusCode)
            {
                return false;
            }

            // a reachable server can still fail (e.g. the configured model is
            // not pulled), so smoke-test an actual completion
            var testCompletion = await GenerateCompletionInternal("Are you there? Reply with a yes or no.");
            if (testCompletion?.Error is not null)
            {
                _logger.LogError("Error: {error}", testCompletion.Error);
                return false;
            }

            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Couldn't reach language model at {url}", _configuration.Url);
            return false;
        }
    }

    /// <summary>
    /// Generates a completion for <paramref name="prompt"/>.
    /// Returns null when the model is not configured, not available, or the
    /// request fails. <see cref="IsEnabled"/> must be called first.
    /// </summary>
    public async Task<ILanguageModelCompletionResponse?> GenerateCompletion(string prompt)
    {
        using var scope = _logger.BeginScope(nameof(OllamaLanguageModelClient));

        if (!ReadyToServe())
        {
            return null;
        }

        // ReadyToServe() guarantees _configuration is not null
        if (_configuration!.CacheResponses && _cacheCompletion.TryGetValue(prompt, out var cachedResponse))
        {
            _logger.LogDebug("Returning cached response for prompt: {prompt}", prompt);
            return cachedResponse;
        }

        var response = await GenerateCompletionInternal(prompt);
        if (response is null)
        {
            return null;
        }
        if (response.Error is not null)
        {
            // log the error as a value, not as a message template (CA2254)
            _logger.LogError("{error}", response.Error);
            return null;
        }

        if (_configuration.CacheResponses && response.Response is not null)
        {
            _cacheCompletion[prompt] = response;
        }

        return response;
    }

    /// <summary>
    /// Generates a chat completion for the given conversation.
    /// Returns null when the model is not configured, not available, or the
    /// request fails. <see cref="IsEnabled"/> must be called first.
    /// </summary>
    public async Task<ILanguageModelCompletionResponse?> GenerateChatCompletion(ILanguageModelChatCompletionMessage[] messages)
    {
        using var scope = _logger.BeginScope(nameof(OllamaLanguageModelClient));

        if (!ReadyToServe())
        {
            return null;
        }

        // ReadyToServe() guarantees _configuration is not null
        if (_configuration!.CacheResponses && _cacheChatCompletion.TryGetValue(messages, out var cachedResponse))
        {
            _logger.LogDebug("Returning cached response for message: {lastMessage}", messages.Last().Content);
            return cachedResponse;
        }

        var response = await GenerateChatCompletionInternal(messages);
        if (response is null)
        {
            return null;
        }
        if (response.Error is not null)
        {
            // log the error as a value, not as a message template (CA2254)
            _logger.LogError("{error}", response.Error);
            return null;
        }

        if (_configuration.CacheResponses && response.Response is not null)
        {
            _cacheChatCompletion[messages] = response;
        }

        return response;
    }

    // Shared preconditions for both Generate* entry points: configuration is
    // present, availability was probed, and the probe succeeded.
    private bool ReadyToServe()
    {
        if (_configuration is null)
        {
            return false;
        }

        if (!_lmAvailable.HasValue)
        {
            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabled));
            return false;
        }

        return _lmAvailable.Value;
    }

    // POSTs a prompt to Ollama's /api/generate endpoint; returns null on any failure.
    private async Task<OllamaLanguageModelCompletionResponse?> GenerateCompletionInternal(string prompt)
    {
        Debug.Assert(_configuration != null, "Configuration is null");

        try
        {
            var url = $"{_configuration.Url}/api/generate";
            _logger.LogDebug("Requesting completion. Prompt: {prompt}", prompt);

            var response = await _httpClient.PostAsJsonAsync(url,
                new
                {
                    prompt,
                    model = _configuration.Model,
                    stream = false
                }
            );
            _logger.LogDebug("Response: {response}", response.StatusCode);

            var res = await response.Content.ReadFromJsonAsync<OllamaLanguageModelCompletionResponse>();
            if (res is null)
            {
                return res;
            }

            // record the request URL so it can be surfaced in the mock output
            res.RequestUrl = url;
            return res;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to generate completion");
            return null;
        }
    }

    // POSTs a conversation to Ollama's /api/chat endpoint; returns null on any failure.
    private async Task<OllamaLanguageModelChatCompletionResponse?> GenerateChatCompletionInternal(ILanguageModelChatCompletionMessage[] messages)
    {
        Debug.Assert(_configuration != null, "Configuration is null");

        try
        {
            var url = $"{_configuration.Url}/api/chat";
            _logger.LogDebug("Requesting chat completion. Message: {lastMessage}", messages.Last().Content);

            var response = await _httpClient.PostAsJsonAsync(url,
                new
                {
                    messages,
                    model = _configuration.Model,
                    stream = false
                }
            );
            _logger.LogDebug("Response: {response}", response.StatusCode);

            var res = await response.Content.ReadFromJsonAsync<OllamaLanguageModelChatCompletionResponse>();
            if (res is null)
            {
                return res;
            }

            // record the request URL so it can be surfaced in the mock output
            res.RequestUrl = url;
            return res;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to generate chat completion");
            return null;
        }
    }

    // Element-wise equality for message arrays so chat cache keys compare by value.
    private sealed class MessageSequenceComparer : IEqualityComparer<ILanguageModelChatCompletionMessage[]>
    {
        public static readonly MessageSequenceComparer Instance = new();

        public bool Equals(ILanguageModelChatCompletionMessage[]? x, ILanguageModelChatCompletionMessage[]? y)
        {
            if (ReferenceEquals(x, y))
            {
                return true;
            }

            // relies on the message type's Equals override for element comparison
            return x is not null && y is not null && x.SequenceEqual(y);
        }

        public int GetHashCode(ILanguageModelChatCompletionMessage[] obj)
        {
            var hash = new HashCode();
            foreach (var message in obj)
            {
                hash.Add(message.Content);
                hash.Add(message.Role);
            }
            return hash.ToHashCode();
        }
    }
}

/// <summary>
/// Sequence-aware lookup helpers for caches keyed by message arrays
/// (array keys compare by reference by default, so a matching key has to be
/// located by comparing elements).
/// </summary>
internal static class CacheChatCompletionExtensions
{
    // Returns the stored key whose messages are element-wise equal to the
    // supplied ones, or null when no such key exists.
    public static OllamaLanguageModelChatCompletionMessage[]? GetKey(
        this Dictionary<OllamaLanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> cache,
        ILanguageModelChatCompletionMessage[] messages)
    {
        foreach (var candidate in cache.Keys)
        {
            if (candidate.SequenceEqual(messages))
            {
                return candidate;
            }
        }

        return null;
    }

    // TryGetValue that matches keys by sequence equality instead of reference
    // equality; value is null if and only if the method returns false.
    public static bool TryGetValue(
        this Dictionary<OllamaLanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> cache,
        ILanguageModelChatCompletionMessage[] messages, out OllamaLanguageModelChatCompletionResponse? value)
    {
        var key = cache.GetKey(messages);
        if (key is not null)
        {
            value = cache[key];
            return true;
        }

        value = null;
        return false;
    }
}
75 changes: 75 additions & 0 deletions dev-proxy-abstractions/LanguageModel/OllamaModels.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Text.Json.Serialization;

namespace Microsoft.DevProxy.Abstractions;

/// <summary>
/// Base shape of a response returned by the Ollama REST API.
/// Properties map to Ollama's snake_case JSON fields via
/// <see cref="JsonPropertyNameAttribute"/>; durations and counts are passed
/// through as received.
/// </summary>
public abstract class OllamaResponse : ILanguageModelCompletionResponse
{
    [JsonPropertyName("created_at")]
    public DateTime CreatedAt { get; set; } = DateTime.MinValue;
    public bool Done { get; set; } = false;
    // Error message reported by the API; null on success.
    public string? Error { get; set; }
    [JsonPropertyName("eval_count")]
    public long EvalCount { get; set; }
    [JsonPropertyName("eval_duration")]
    public long EvalDuration { get; set; }
    [JsonPropertyName("load_duration")]
    public long LoadDuration { get; set; }
    public string Model { get; set; } = string.Empty;
    [JsonPropertyName("prompt_eval_count")]
    public long PromptEvalCount { get; set; }
    [JsonPropertyName("prompt_eval_duration")]
    public long PromptEvalDuration { get; set; }
    // Virtual so chat responses can redirect this to the nested message content.
    public virtual string? Response { get; set; }
    [JsonPropertyName("total_duration")]
    public long TotalDuration { get; set; }
    // custom property added to log in the mock output; not part of the Ollama payload
    public string RequestUrl { get; set; } = string.Empty;
}

/// <summary>
/// Response from Ollama's /api/generate (plain completion) endpoint.
/// </summary>
public class OllamaLanguageModelCompletionResponse : OllamaResponse
{
    // Conversation context tokens returned by Ollama; empty when not provided.
    public int[] Context { get; set; } = [];
}

/// <summary>
/// Response from Ollama's /api/chat endpoint. The generated text lives in
/// <see cref="Message"/>; <see cref="Response"/> is an adapter over it so the
/// chat response satisfies the common completion contract.
/// </summary>
public class OllamaLanguageModelChatCompletionResponse : OllamaResponse
{
    public OllamaLanguageModelChatCompletionMessage Message { get; set; } = new();
    public override string? Response
    {
        get => Message.Content;
        set
        {
            // null assignments are ignored rather than clearing the message
            if (value is null)
            {
                return;
            }

            // NOTE(review): this replaces the whole message, so any previously
            // set Role is reset to empty — confirm this is intended by callers.
            Message = new() { Content = value };
        }
    }
}

/// <summary>
/// Chat message payload exchanged with the Ollama /api/chat endpoint.
/// Implements value equality over <see cref="Content"/> and <see cref="Role"/>
/// so message arrays can be compared element-wise for caching.
/// </summary>
public class OllamaLanguageModelChatCompletionMessage : ILanguageModelChatCompletionMessage
{
    public string Content { get; set; } = string.Empty;
    public string Role { get; set; } = string.Empty;

    // Two messages are equal when their runtime types are identical and both
    // content and role match.
    public override bool Equals(object? obj) =>
        obj is OllamaLanguageModelChatCompletionMessage other
        && GetType() == other.GetType()
        && Content == other.Content
        && Role == other.Role;

    // Keep the hash consistent with Equals (same fields).
    public override int GetHashCode() => HashCode.Combine(Content, Role);
}
Loading

0 comments on commit fa465d4

Please sign in to comment.