diff --git a/dev-proxy-abstractions/ILanguageModelClient.cs b/dev-proxy-abstractions/ILanguageModelClient.cs
deleted file mode 100644
index 16e343a5..00000000
--- a/dev-proxy-abstractions/ILanguageModelClient.cs
+++ /dev/null
@@ -1,4 +0,0 @@
-public interface ILanguageModelClient
-{
-    Task<string?> GenerateCompletion(string prompt);
-}
\ No newline at end of file
diff --git a/dev-proxy-abstractions/LanguageModel/ILanguageModelChatCompletionMessage.cs b/dev-proxy-abstractions/LanguageModel/ILanguageModelChatCompletionMessage.cs
new file mode 100644
index 00000000..d1e9bcd1
--- /dev/null
+++ b/dev-proxy-abstractions/LanguageModel/ILanguageModelChatCompletionMessage.cs
@@ -0,0 +1,10 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+namespace Microsoft.DevProxy.Abstractions;
+
+public interface ILanguageModelChatCompletionMessage
+{
+    string Content { get; set; }
+    string Role { get; set; }
+}
\ No newline at end of file
diff --git a/dev-proxy-abstractions/LanguageModel/ILanguageModelClient.cs b/dev-proxy-abstractions/LanguageModel/ILanguageModelClient.cs
new file mode 100644
index 00000000..b3d8f433
--- /dev/null
+++ b/dev-proxy-abstractions/LanguageModel/ILanguageModelClient.cs
@@ -0,0 +1,11 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+namespace Microsoft.DevProxy.Abstractions;
+
+public interface ILanguageModelClient
+{
+    Task<ILanguageModelCompletionResponse?> GenerateChatCompletion(ILanguageModelChatCompletionMessage[] messages);
+    Task<ILanguageModelCompletionResponse?> GenerateCompletion(string prompt);
+    Task<bool> IsEnabled();
+}
\ No newline at end of file
diff --git a/dev-proxy-abstractions/LanguageModel/ILanguageModelCompletionResponse.cs b/dev-proxy-abstractions/LanguageModel/ILanguageModelCompletionResponse.cs
new file mode 100644
index 00000000..76bcfbfd
--- /dev/null
+++ b/dev-proxy-abstractions/LanguageModel/ILanguageModelCompletionResponse.cs
@@ -0,0 +1,10 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+namespace Microsoft.DevProxy.Abstractions;
+
+public interface ILanguageModelCompletionResponse
+{
+    string? Error { get; set; }
+    string? Response { get; set; }
+}
\ No newline at end of file
diff --git a/dev-proxy/LanguageModel/LanguageModelConfiguration.cs b/dev-proxy-abstractions/LanguageModel/LanguageModelConfiguration.cs
similarity index 89%
rename from dev-proxy/LanguageModel/LanguageModelConfiguration.cs
rename to dev-proxy-abstractions/LanguageModel/LanguageModelConfiguration.cs
index c6bb81ae..594d0155 100644
--- a/dev-proxy/LanguageModel/LanguageModelConfiguration.cs
+++ b/dev-proxy-abstractions/LanguageModel/LanguageModelConfiguration.cs
@@ -1,7 +1,7 @@
 // Copyright (c) Microsoft Corporation.
 // Licensed under the MIT License.
 
-namespace Microsoft.DevProxy.LanguageModel;
+namespace Microsoft.DevProxy.Abstractions;
 
 public class LanguageModelConfiguration
 {
diff --git a/dev-proxy-abstractions/LanguageModel/OllamaLanguageModelClient.cs b/dev-proxy-abstractions/LanguageModel/OllamaLanguageModelClient.cs
new file mode 100644
index 00000000..91fd9f28
--- /dev/null
+++ b/dev-proxy-abstractions/LanguageModel/OllamaLanguageModelClient.cs
@@ -0,0 +1,268 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System.Diagnostics;
+using System.Net.Http.Json;
+using Microsoft.Extensions.Logging;
+
+namespace Microsoft.DevProxy.Abstractions;
+
+public class OllamaLanguageModelClient(LanguageModelConfiguration? configuration, ILogger logger) : ILanguageModelClient
+{
+    private readonly LanguageModelConfiguration? _configuration = configuration;
+    private readonly ILogger _logger = logger;
+    private bool? _lmAvailable;
+    private Dictionary<string, OllamaLanguageModelCompletionResponse> _cacheCompletion = new();
+    private Dictionary<ILanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> _cacheChatCompletion = new();
+
+    public async Task<bool> IsEnabled()
+    {
+        if (_lmAvailable.HasValue)
+        {
+            return _lmAvailable.Value;
+        }
+
+        _lmAvailable = await IsEnabledInternal();
+        return _lmAvailable.Value;
+    }
+
+    private async Task<bool> IsEnabledInternal()
+    {
+        if (_configuration is null || !_configuration.Enabled)
+        {
+            return false;
+        }
+
+        if (string.IsNullOrEmpty(_configuration.Url))
+        {
+            _logger.LogError("URL is not set. Language model will be disabled");
+            return false;
+        }
+
+        if (string.IsNullOrEmpty(_configuration.Model))
+        {
+            _logger.LogError("Model is not set. Language model will be disabled");
+            return false;
+        }
+
+        _logger.LogDebug("Checking LM availability at {url}...", _configuration.Url);
+
+        try
+        {
+            // check if lm is on
+            using var client = new HttpClient();
+            var response = await client.GetAsync(_configuration.Url);
+            _logger.LogDebug("Response: {response}", response.StatusCode);
+
+            if (!response.IsSuccessStatusCode)
+            {
+                return false;
+            }
+
+            var testCompletion = await GenerateCompletionInternal("Are you there? Reply with a yes or no.");
+            if (testCompletion?.Error is not null)
+            {
+                _logger.LogError("Error: {error}", testCompletion.Error);
+                return false;
+            }
+
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError(ex, "Couldn't reach language model at {url}", _configuration.Url);
+            return false;
+        }
+    }
+
+    public async Task<ILanguageModelCompletionResponse?> GenerateCompletion(string prompt)
+    {
+        using var scope = _logger.BeginScope(nameof(OllamaLanguageModelClient));
+
+        if (_configuration is null)
+        {
+            return null;
+        }
+
+        if (!_lmAvailable.HasValue)
+        {
+            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabled));
+            return null;
+        }
+
+        if (!_lmAvailable.Value)
+        {
+            return null;
+        }
+
+        if (_configuration.CacheResponses && _cacheCompletion.TryGetValue(prompt, out var cachedResponse))
+        {
+            _logger.LogDebug("Returning cached response for prompt: {prompt}", prompt);
+            return cachedResponse;
+        }
+
+        var response = await GenerateCompletionInternal(prompt);
+        if (response == null)
+        {
+            return null;
+        }
+        if (response.Error is not null)
+        {
+            _logger.LogError(response.Error);
+            return null;
+        }
+        else
+        {
+            if (_configuration.CacheResponses && response.Response is not null)
+            {
+                _cacheCompletion[prompt] = response;
+            }
+
+            return response;
+        }
+    }
+
+    private async Task<OllamaLanguageModelCompletionResponse?> GenerateCompletionInternal(string prompt)
+    {
+        Debug.Assert(_configuration != null, "Configuration is null");
+
+        try
+        {
+            using var client = new HttpClient();
+            var url = $"{_configuration.Url}/api/generate";
+            _logger.LogDebug("Requesting completion. Prompt: {prompt}", prompt);
+
+            var response = await client.PostAsJsonAsync(url,
+                new
+                {
+                    prompt,
+                    model = _configuration.Model,
+                    stream = false
+                }
+            );
+            _logger.LogDebug("Response: {response}", response.StatusCode);
+
+            var res = await response.Content.ReadFromJsonAsync<OllamaLanguageModelCompletionResponse>();
+            if (res is null)
+            {
+                return res;
+            }
+
+            res.RequestUrl = url;
+            return res;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError(ex, "Failed to generate completion");
+            return null;
+        }
+    }
+
+    public async Task<ILanguageModelCompletionResponse?> GenerateChatCompletion(ILanguageModelChatCompletionMessage[] messages)
+    {
+        using var scope = _logger.BeginScope(nameof(OllamaLanguageModelClient));
+
+        if (_configuration is null)
+        {
+            return null;
+        }
+
+        if (!_lmAvailable.HasValue)
+        {
+            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabled));
+            return null;
+        }
+
+        if (!_lmAvailable.Value)
+        {
+            return null;
+        }
+
+        if (_configuration.CacheResponses && _cacheChatCompletion.TryGetValue(messages, out var cachedResponse))
+        {
+            _logger.LogDebug("Returning cached response for message: {lastMessage}", messages.Last().Content);
+            return cachedResponse;
+        }
+
+        var response = await GenerateChatCompletionInternal(messages);
+        if (response == null)
+        {
+            return null;
+        }
+        if (response.Error is not null)
+        {
+            _logger.LogError(response.Error);
+            return null;
+        }
+        else
+        {
+            if (_configuration.CacheResponses && response.Response is not null)
+            {
+                _cacheChatCompletion[messages] = response;
+            }
+
+            return response;
+        }
+    }
+
+    private async Task<OllamaLanguageModelChatCompletionResponse?> GenerateChatCompletionInternal(ILanguageModelChatCompletionMessage[] messages)
+    {
+        Debug.Assert(_configuration != null, "Configuration is null");
+
+        try
+        {
+            using var client = new HttpClient();
+            var url = $"{_configuration.Url}/api/chat";
+            _logger.LogDebug("Requesting chat completion. Message: {lastMessage}", messages.Last().Content);
+
+            var response = await client.PostAsJsonAsync(url,
+                new
+                {
+                    messages,
+                    model = _configuration.Model,
+                    stream = false
+                }
+            );
+            _logger.LogDebug("Response: {response}", response.StatusCode);
+
+            var res = await response.Content.ReadFromJsonAsync<OllamaLanguageModelChatCompletionResponse>();
+            if (res is null)
+            {
+                return res;
+            }
+
+            res.RequestUrl = url;
+            return res;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError(ex, "Failed to generate chat completion");
+            return null;
+        }
+    }
+}
+
+internal static class CacheChatCompletionExtensions
+{
+    public static ILanguageModelChatCompletionMessage[]? GetKey(
+        this Dictionary<ILanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> cache,
+        ILanguageModelChatCompletionMessage[] messages)
+    {
+        return cache.Keys.FirstOrDefault(k => k.SequenceEqual(messages));
+    }
+
+    public static bool TryGetValue(
+        this Dictionary<ILanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> cache,
+        ILanguageModelChatCompletionMessage[] messages, out OllamaLanguageModelChatCompletionResponse? value)
+    {
+        var key = cache.GetKey(messages);
+        if (key is null)
+        {
+            value = null;
+            return false;
+        }
+
+        value = cache[key];
+        return true;
+    }
+}
\ No newline at end of file
diff --git a/dev-proxy-abstractions/LanguageModel/OllamaModels.cs b/dev-proxy-abstractions/LanguageModel/OllamaModels.cs
new file mode 100644
index 00000000..178afffd
--- /dev/null
+++ b/dev-proxy-abstractions/LanguageModel/OllamaModels.cs
@@ -0,0 +1,75 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System.Text.Json.Serialization;
+
+namespace Microsoft.DevProxy.Abstractions;
+
+public abstract class OllamaResponse : ILanguageModelCompletionResponse
+{
+    [JsonPropertyName("created_at")]
+    public DateTime CreatedAt { get; set; } = DateTime.MinValue;
+    public bool Done { get; set; } = false;
+    public string? Error { get; set; }
+    [JsonPropertyName("eval_count")]
+    public long EvalCount { get; set; }
+    [JsonPropertyName("eval_duration")]
+    public long EvalDuration { get; set; }
+    [JsonPropertyName("load_duration")]
+    public long LoadDuration { get; set; }
+    public string Model { get; set; } = string.Empty;
+    [JsonPropertyName("prompt_eval_count")]
+    public long PromptEvalCount { get; set; }
+    [JsonPropertyName("prompt_eval_duration")]
+    public long PromptEvalDuration { get; set; }
+    public virtual string? Response { get; set; }
+    [JsonPropertyName("total_duration")]
+    public long TotalDuration { get; set; }
+    // custom property added to log in the mock output
+    public string RequestUrl { get; set; } = string.Empty;
+}
+
+public class OllamaLanguageModelCompletionResponse : OllamaResponse
+{
+    public int[] Context { get; set; } = [];
+}
+
+public class OllamaLanguageModelChatCompletionResponse : OllamaResponse
+{
+    public OllamaLanguageModelChatCompletionMessage Message { get; set; } = new();
+    public override string? Response
+    {
+        get => Message.Content;
+        set
+        {
+            if (value is null)
+            {
+                return;
+            }
+
+            Message = new() { Content = value };
+        }
+    }
+}
+
+public class OllamaLanguageModelChatCompletionMessage : ILanguageModelChatCompletionMessage
+{
+    public string Content { get; set; } = string.Empty;
+    public string Role { get; set; } = string.Empty;
+
+    public override bool Equals(object? obj)
+    {
+        if (obj is null || GetType() != obj.GetType())
+        {
+            return false;
+        }
+
+        OllamaLanguageModelChatCompletionMessage m = (OllamaLanguageModelChatCompletionMessage)obj;
+        return Content == m.Content && Role == m.Role;
+    }
+
+    public override int GetHashCode()
+    {
+        return HashCode.Combine(Content, Role);
+    }
+}
\ No newline at end of file
diff --git a/dev-proxy-plugins/Mocks/OpenAIMockResponsePlugin.cs b/dev-proxy-plugins/Mocks/OpenAIMockResponsePlugin.cs
new file mode 100644
index 00000000..3fe8c02b
--- /dev/null
+++ b/dev-proxy-plugins/Mocks/OpenAIMockResponsePlugin.cs
@@ -0,0 +1,373 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System.Net;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using Microsoft.DevProxy.Abstractions;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.Logging;
+using Titanium.Web.Proxy.Models;
+
+public class OpenAIMockResponsePlugin : BaseProxyPlugin
+{
+    public OpenAIMockResponsePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet<UrlToWatch> urlsToWatch, IConfigurationSection? configSection = null) : base(pluginEvents, context, logger, urlsToWatch, configSection)
+    {
+    }
+
+    public override string Name => nameof(OpenAIMockResponsePlugin);
+
+    public override void Register()
+    {
+        base.Register();
+
+        using var scope = Logger.BeginScope(Name);
+
+        Logger.LogInformation("Checking language model availability...");
+        if (!Context.LanguageModelClient.IsEnabled().Result)
+        {
+            Logger.LogError("Local language model is not enabled. The {plugin} will not be used.", Name);
+            return;
+        }
+
+        PluginEvents.BeforeRequest += OnRequest;
+    }
+
+    private async Task OnRequest(object sender, ProxyRequestArgs e)
+    {
+        using var scope = Logger.BeginScope(Name);
+
+        var request = e.Session.HttpClient.Request;
+        if (!request.Method.Equals("POST", StringComparison.OrdinalIgnoreCase) ||
+            !request.HasBody)
+        {
+            return;
+        }
+
+        if (!TryGetOpenAIRequest(request.BodyString, out var openAiRequest))
+        {
+            return;
+        }
+
+        if (openAiRequest is OpenAICompletionRequest completionRequest)
+        {
+            var ollamaResponse = (await Context.LanguageModelClient.GenerateCompletion(completionRequest.Prompt))
+                as OllamaLanguageModelCompletionResponse;
+            if (ollamaResponse is null)
+            {
+                return;
+            }
+            if (ollamaResponse.Error is not null)
+            {
+                Logger.LogError("Error from Ollama language model: {error}", ollamaResponse.Error);
+                return;
+            }
+
+            var openAiResponse = ollamaResponse.ConvertToOpenAIResponse();
+            SendMockResponse<OpenAICompletionResponse>(openAiResponse, ollamaResponse.RequestUrl, e);
+        }
+        else if (openAiRequest is OpenAIChatCompletionRequest chatRequest)
+        {
+            var ollamaResponse = (await Context.LanguageModelClient
+                .GenerateChatCompletion(chatRequest.Messages.ConvertToLanguageModelChatCompletionMessage()))
+                as OllamaLanguageModelChatCompletionResponse;
+            if (ollamaResponse is null)
+            {
+                return;
+            }
+            if (ollamaResponse.Error is not null)
+            {
+                Logger.LogError("Error from Ollama language model: {error}", ollamaResponse.Error);
+                return;
+            }
+
+            var openAiResponse = ollamaResponse.ConvertToOpenAIResponse();
+            SendMockResponse<OpenAIChatCompletionResponse>(openAiResponse, ollamaResponse.RequestUrl, e);
+        }
+        else
+        {
+            Logger.LogError("Unknown OpenAI request type.");
+        }
+    }
+
+    private bool TryGetOpenAIRequest(string content, out OpenAIRequest? request)
+    {
+        request = null;
+
+        if (string.IsNullOrEmpty(content))
+        {
+            return false;
+        }
+
+        try
+        {
+            Logger.LogDebug("Checking if the request is an OpenAI request...");
+
+            var rawRequest = JsonSerializer.Deserialize<JsonElement>(content, ProxyUtils.JsonSerializerOptions);
+
+            if (rawRequest.TryGetProperty("prompt", out _))
+            {
+                Logger.LogDebug("Request is a completion request");
+                request = JsonSerializer.Deserialize<OpenAICompletionRequest>(content, ProxyUtils.JsonSerializerOptions);
+                return true;
+            }
+
+            if (rawRequest.TryGetProperty("messages", out _))
+            {
+                Logger.LogDebug("Request is a chat completion request");
+                request = JsonSerializer.Deserialize<OpenAIChatCompletionRequest>(content, ProxyUtils.JsonSerializerOptions);
+                return true;
+            }
+
+            Logger.LogDebug("Request is not an OpenAI request.");
+            return false;
+        }
+        catch (JsonException ex)
+        {
+            Logger.LogDebug(ex, "Failed to deserialize OpenAI request.");
+            return false;
+        }
+    }
+
+    private void SendMockResponse<TResponse>(OpenAIResponse response, string localLmUrl, ProxyRequestArgs e) where TResponse : OpenAIResponse
+    {
+        e.Session.GenericResponse(
+            // we need this cast or else the JsonSerializer drops derived properties
+            JsonSerializer.Serialize((TResponse)response, ProxyUtils.JsonSerializerOptions),
+            HttpStatusCode.OK,
+            [
+                new HttpHeader("content-type", "application/json"),
+                new HttpHeader("access-control-allow-origin", "*")
+            ]
+        );
+        e.ResponseState.HasBeenSet = true;
+        Logger.LogRequest([$"200 {localLmUrl}"], MessageType.Mocked, new LoggingContext(e.Session));
+    }
+}
+
+#region models
+
+internal abstract class OpenAIRequest
+{
+    [JsonPropertyName("frequency_penalty")]
+    public long FrequencyPenalty { get; set; }
+    [JsonPropertyName("max_tokens")]
+    public long MaxTokens { get; set; }
+    [JsonPropertyName("presence_penalty")]
+    public long PresencePenalty { get; set; }
+    public object? Stop { get; set; }
+    public bool Stream { get; set; }
+    public long Temperature { get; set; }
+    [JsonPropertyName("top_p")]
+    public double TopP { get; set; }
+}
+
+internal abstract class OpenAIResponse
+{
+    public long Created { get; set; }
+    public string Id { get; set; } = string.Empty;
+    public string Model { get; set; } = string.Empty;
+    public string Object { get; set; } = "text_completion";
+    [JsonPropertyName("prompt_filter_results")]
+    public OpenAIResponsePromptFilterResult[] PromptFilterResults { get; set; } = [];
+    public OpenAIResponseUsage Usage { get; set; } = new();
+}
+
+internal abstract class OpenAIResponse<TChoice> : OpenAIResponse
+{
+    public TChoice[] Choices { get; set; } = [];
+}
+
+internal class OpenAIResponseUsage
+{
+    [JsonPropertyName("completion_tokens")]
+    public long CompletionTokens { get; set; }
+    [JsonPropertyName("prompt_tokens")]
+    public long PromptTokens { get; set; }
+    [JsonPropertyName("total_tokens")]
+    public long TotalTokens { get; set; }
+}
+
+internal abstract class OpenAIResponseChoice
+{
+    [JsonPropertyName("content_filter_results")]
+    public Dictionary<string, OpenAIResponseContentFilterResult> ContentFilterResults { get; set; } = new();
+    [JsonPropertyName("finish_reason")]
+    public string FinishReason { get; set; } = "length";
+    public long Index { get; set; }
+    [JsonIgnore(Condition = JsonIgnoreCondition.Never)]
+    public object? Logprobs { get; set; }
+}
+
+internal class OpenAIResponsePromptFilterResult
+{
+    [JsonPropertyName("content_filter_results")]
+    public Dictionary<string, OpenAIResponseContentFilterResult> ContentFilterResults { get; set; } = new();
+    [JsonPropertyName("prompt_index")]
+    public long PromptIndex { get; set; }
+}
+
+internal class OpenAIResponseContentFilterResult
+{
+    public bool Filtered { get; set; }
+    public string Severity { get; set; } = "safe";
+}
+
+internal class OpenAICompletionRequest : OpenAIRequest
+{
+    public string Prompt { get; set; } = string.Empty;
+}
+
+internal class OpenAICompletionResponse : OpenAIResponse<OpenAICompletionResponseChoice>
+{
+}
+
+internal class OpenAICompletionResponseChoice : OpenAIResponseChoice
+{
+    public string Text { get; set; } = string.Empty;
+}
+
+internal class OpenAIChatCompletionRequest : OpenAIRequest
+{
+    public OpenAIChatMessage[] Messages { get; set; } = [];
+}
+
+internal class OpenAIChatMessage
+{
+    public string Content { get; set; } = string.Empty;
+    public string Role { get; set; } = string.Empty;
+}
+
+internal class OpenAIChatCompletionResponse : OpenAIResponse<OpenAIChatCompletionResponseChoice>
+{
+}
+
+internal class OpenAIChatCompletionResponseChoice : OpenAIResponseChoice
+{
+    public OpenAIChatCompletionResponseChoiceMessage Message { get; set; } = new();
+}
+
+internal class OpenAIChatCompletionResponseChoiceMessage
+{
+    public string Content { get; set; } = string.Empty;
+    public string Role { get; set; } = string.Empty;
+}
+
+#endregion
+
+#region extensions
+
+internal static class OllamaLanguageModelCompletionResponseExtensions
+{
+    public static OpenAICompletionResponse ConvertToOpenAIResponse(this OllamaLanguageModelCompletionResponse response)
+    {
+        return new OpenAICompletionResponse
+        {
+            Id = Guid.NewGuid().ToString(),
+            Object = "text_completion",
+            Created = ((DateTimeOffset)response.CreatedAt).ToUnixTimeSeconds(),
+            Model = response.Model,
+            PromptFilterResults =
+            [
+                new OpenAIResponsePromptFilterResult
+                {
+                    PromptIndex = 0,
+                    ContentFilterResults = new Dictionary<string, OpenAIResponseContentFilterResult>
+                    {
+                        { "hate", new() { Filtered = false, Severity = "safe" } },
+                        { "self_harm", new() { Filtered = false, Severity = "safe" } },
+                        { "sexual", new() { Filtered = false, Severity = "safe" } },
+                        { "violence", new() { Filtered = false, Severity = "safe" } }
+                    }
+                }
+            ],
+            Choices =
+            [
+                new OpenAICompletionResponseChoice
+                {
+                    Text = response.Response ?? string.Empty,
+                    Index = 0,
+                    FinishReason = "length",
+                    ContentFilterResults = new Dictionary<string, OpenAIResponseContentFilterResult>
+                    {
+                        { "hate", new() { Filtered = false, Severity = "safe" } },
+                        { "self_harm", new() { Filtered = false, Severity = "safe" } },
+                        { "sexual", new() { Filtered = false, Severity = "safe" } },
+                        { "violence", new() { Filtered = false, Severity = "safe" } }
+                    }
+                }
+            ],
+            Usage = new OpenAIResponseUsage
+            {
+                PromptTokens = response.PromptEvalCount,
+                CompletionTokens = response.EvalCount,
+                TotalTokens = response.PromptEvalCount + response.EvalCount
+            }
+        };
+    }
+}
+
+internal static class OllamaLanguageModelChatCompletionResponseExtensions
+{
+    public static OpenAIChatCompletionResponse ConvertToOpenAIResponse(this OllamaLanguageModelChatCompletionResponse response)
+    {
+        return new OpenAIChatCompletionResponse
+        {
+            Choices = [new OpenAIChatCompletionResponseChoice
+            {
+                ContentFilterResults = new Dictionary<string, OpenAIResponseContentFilterResult>
+                {
+                    { "hate", new() { Filtered = false, Severity = "safe" } },
+                    { "self_harm", new() { Filtered = false, Severity = "safe" } },
+                    { "sexual", new() { Filtered = false, Severity = "safe" } },
+                    { "violence", new() { Filtered = false, Severity = "safe" } }
+                },
+                FinishReason = "stop",
+                Index = 0,
+                Message = new()
+                {
+                    Content = response.Message.Content,
+                    Role = response.Message.Role
+                }
+            }],
+            Created = ((DateTimeOffset)response.CreatedAt).ToUnixTimeSeconds(),
+            Id = Guid.NewGuid().ToString(),
+            Model = response.Model,
+            Object = "chat.completion",
+            PromptFilterResults =
+            [
+                new OpenAIResponsePromptFilterResult
+                {
+                    PromptIndex = 0,
+                    ContentFilterResults = new Dictionary<string, OpenAIResponseContentFilterResult>
+                    {
+                        { "hate", new() { Filtered = false, Severity = "safe" } },
+                        { "self_harm", new() { Filtered = false, Severity = "safe" } },
+                        { "sexual", new() { Filtered = false, Severity = "safe" } },
+                        { "violence", new() { Filtered = false, Severity = "safe" } }
+                    }
+                }
+            ],
+            Usage = new OpenAIResponseUsage
+            {
+                PromptTokens = response.PromptEvalCount,
+                CompletionTokens = response.EvalCount,
+                TotalTokens = response.PromptEvalCount + response.EvalCount
+            }
+        };
+    }
+}
+
+internal static class OpenAIChatMessageExtensions
+{
+    public static ILanguageModelChatCompletionMessage[] ConvertToLanguageModelChatCompletionMessage(this OpenAIChatMessage[] messages)
+    {
+        return messages.Select(m => new OllamaLanguageModelChatCompletionMessage
+        {
+            Content = m.Content,
+            Role = m.Role
+        }).ToArray();
+    }
+}
+
+#endregion
\ No newline at end of file
diff --git a/dev-proxy-plugins/RequestLogs/OpenApiSpecGeneratorPlugin.cs b/dev-proxy-plugins/RequestLogs/OpenApiSpecGeneratorPlugin.cs
index 0e813584..0285dfd5 100644
--- a/dev-proxy-plugins/RequestLogs/OpenApiSpecGeneratorPlugin.cs
+++ b/dev-proxy-plugins/RequestLogs/OpenApiSpecGeneratorPlugin.cs
@@ -452,14 +452,14 @@ private async Task<string> GetOperationId(string method, string serverUrl, strin
     {
         var prompt = $"For the specified request, generate an operation ID, compatible with an OpenAPI spec. Respond with just the ID in plain-text format. For example, for request such as `GET https://api.contoso.com/books/{{books-id}}` you return `getBookById`. For a request like `GET https://api.contoso.com/books/{{books-id}}/authors` you return `getAuthorsForBookById`. Request: {method.ToUpper()} {serverUrl}{parametrizedPath}";
         var id = await Context.LanguageModelClient.GenerateCompletion(prompt);
-        return id ?? $"{method}{parametrizedPath.Replace('/', '.')}";
+        return id?.Response ?? $"{method}{parametrizedPath.Replace('/', '.')}";
     }
 
     private async Task<string> GetOperationDescription(string method, string serverUrl, string parametrizedPath)
     {
         var prompt = $"You're an expert in OpenAPI. You help developers build great OpenAPI specs for use with LLMs. For the specified request, generate a one-sentence description. Respond with just the description. For example, for a request such as `GET https://api.contoso.com/books/{{books-id}}` you return `Get a book by ID`. Request: {method.ToUpper()} {serverUrl}{parametrizedPath}";
         var description = await Context.LanguageModelClient.GenerateCompletion(prompt);
-        return description ?? $"{method} {parametrizedPath}";
+        return description?.Response ?? $"{method} {parametrizedPath}";
     }
 
     /**
diff --git a/dev-proxy/LanguageModel/LanguageModelClient.cs b/dev-proxy/LanguageModel/LanguageModelClient.cs
deleted file mode 100644
index a020b3f0..00000000
--- a/dev-proxy/LanguageModel/LanguageModelClient.cs
+++ /dev/null
@@ -1,146 +0,0 @@
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT License.
-
-using System.Diagnostics;
-using System.Net.Http.Json;
-using Microsoft.Extensions.Logging;
-
-namespace Microsoft.DevProxy.LanguageModel;
-
-public class LanguageModelClient(LanguageModelConfiguration? configuration, ILogger logger) : ILanguageModelClient
-{
-    private readonly LanguageModelConfiguration? _configuration = configuration;
-    private readonly ILogger _logger = logger;
-    private bool? _lmAvailable;
-    private Dictionary<string, string> _cache = new();
-
-    public async Task<string?> GenerateCompletion(string prompt)
-    {
-        using var scope = _logger.BeginScope("Language Model");
-
-        if (_configuration == null || !_configuration.Enabled)
-        {
-            // LM turned off. Nothing to do, nothing to report
-            return null;
-        }
-
-        if (!_lmAvailable.HasValue)
-        {
-            if (string.IsNullOrEmpty(_configuration.Url))
-            {
-                _logger.LogError("URL is not set. Language model will be disabled");
-                _lmAvailable = false;
-                return null;
-            }
-
-            if (string.IsNullOrEmpty(_configuration.Model))
-            {
-                _logger.LogError("Model is not set. Language model will be disabled");
-                _lmAvailable = false;
-                return null;
-            }
-
-            _logger.LogDebug("Checking availability...");
-            _lmAvailable = await IsLmAvailable();
-
-            // we want to log this only once
-            if (!_lmAvailable.Value)
-            {
-                _logger.LogError("{model} at {url} is not available", _configuration.Model, _configuration.Url);
-                return null;
-            }
-        }
-
-        if (!_lmAvailable.Value)
-        {
-            return null;
-        }
-
-        if (_configuration.CacheResponses && _cache.TryGetValue(prompt, out var cachedResponse))
-        {
-            _logger.LogDebug("Returning cached response for prompt: {prompt}", prompt);
-            return cachedResponse;
-        }
-
-        var response = await GenerateCompletionInternal(prompt);
-        if (response == null)
-        {
-            return null;
-        }
-        if (response.Error is not null)
-        {
-            _logger.LogError(response.Error);
-            return null;
-        }
-        else
-        {
-            if (_configuration.CacheResponses && response.Response is not null)
-            {
-                _cache[prompt] = response.Response;
-            }
-
-            return response.Response;
-        }
-    }
-
-    private async Task<LanguageModelResponse?> GenerateCompletionInternal(string prompt)
-    {
-        Debug.Assert(_configuration != null, "Configuration is null");
-
-        try
-        {
-            using var client = new HttpClient();
-            var url = $"{_configuration.Url}/api/generate";
-            _logger.LogDebug("Requesting completion. Prompt: {prompt}", prompt);
-
-            var response = await client.PostAsJsonAsync(url,
-                new
-                {
-                    prompt,
-                    model = _configuration.Model,
-                    stream = false
-                }
-            );
-            return await response.Content.ReadFromJsonAsync<LanguageModelResponse>();
-        }
-        catch (Exception ex)
-        {
-            _logger.LogError(ex, "Failed to generate completion");
-            return null;
-        }
-    }
-
-    private async Task<bool> IsLmAvailable()
-    {
-        Debug.Assert(_configuration != null, "Configuration is null");
-
-        _logger.LogDebug("Checking LM availability at {url}...", _configuration.Url);
-
-        try
-        {
-            // check if lm is on
-            using var client = new HttpClient();
-            var response = await client.GetAsync(_configuration.Url);
-            _logger.LogDebug("Response: {response}", response.StatusCode);
-
-            if (!response.IsSuccessStatusCode)
-            {
-                return false;
-            }
-
-            var testCompletion = await GenerateCompletionInternal("Are you there? Reply with a yes or no.");
-            if (testCompletion?.Error is not null)
-            {
-                _logger.LogError("Error: {error}", testCompletion.Error);
-                return false;
-            }
-
-            return true;
-        }
-        catch (Exception ex)
-        {
-            _logger.LogError(ex, "Couldn't reach language model at {url}", _configuration.Url);
-            return false;
-        }
-    }
-}
\ No newline at end of file
diff --git a/dev-proxy/LanguageModel/LanguageModelResponse.cs b/dev-proxy/LanguageModel/LanguageModelResponse.cs
deleted file mode 100644
index a15fefaf..00000000
--- a/dev-proxy/LanguageModel/LanguageModelResponse.cs
+++ /dev/null
@@ -1,10 +0,0 @@
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT License.
-
-namespace Microsoft.DevProxy.LanguageModel;
-
-public class LanguageModelResponse
-{
-    public string? Response { get; init; }
-    public string? Error { get; init; }
-}
\ No newline at end of file
diff --git a/dev-proxy/Logging/ProxyConsoleFormatter.cs b/dev-proxy/Logging/ProxyConsoleFormatter.cs
index 9cf5aefa..deb5717c 100644
--- a/dev-proxy/Logging/ProxyConsoleFormatter.cs
+++ b/dev-proxy/Logging/ProxyConsoleFormatter.cs
@@ -123,14 +123,17 @@ private void WriteMessageBoxedWithInvertedLabels(string? message, LogLevel logLe
         var (bgColor, fgColor) = GetLogLevelColor(logLevel);
 
         textWriter.WriteColoredMessage($" {label} ", bgColor, fgColor);
-        textWriter.Write($"{labelSpacing}{_boxSpacing}{(logLevel == LogLevel.Debug ? $"[{DateTime.Now}] " : "")}");
+        textWriter.Write($"{labelSpacing}{_boxSpacing}{(logLevel == LogLevel.Debug ? $"[{DateTime.Now:T}] " : "")}");
 
         if (_options.IncludeScopes && scopeProvider is not null)
         {
             scopeProvider.ForEachScope((scope, state) =>
             {
-                state.Write(scope);
-                state.Write(": ");
+                if (scope is string scopeString)
+                {
+                    textWriter.Write(scopeString);
+                    textWriter.Write(": ");
+                }
             }, textWriter);
         }
 
diff --git a/dev-proxy/Program.cs b/dev-proxy/Program.cs
index e388cd3a..0c82a414 100644
--- a/dev-proxy/Program.cs
+++ b/dev-proxy/Program.cs
@@ -3,7 +3,6 @@
 using Microsoft.DevProxy;
 using Microsoft.DevProxy.Abstractions;
-using Microsoft.DevProxy.LanguageModel;
 using Microsoft.DevProxy.Logging;
 using Microsoft.Extensions.Logging;
 using Microsoft.Extensions.Logging.Console;
 
@@ -32,7 +31,7 @@ ILogger BuildLogger()
 
 var logger = BuildLogger();
 
-var lmClient = new LanguageModelClient(ProxyCommandHandler.Configuration.LanguageModel, logger);
+var lmClient = new OllamaLanguageModelClient(ProxyCommandHandler.Configuration.LanguageModel, logger);
 IProxyContext context = new ProxyContext(ProxyCommandHandler.Configuration, ProxyEngine.Certificate, lmClient);
 
 ProxyHost proxyHost = new();
diff --git a/dev-proxy/ProxyCommandHandler.cs b/dev-proxy/ProxyCommandHandler.cs
index 81d7b933..6769b533 100755
--- a/dev-proxy/ProxyCommandHandler.cs
+++ b/dev-proxy/ProxyCommandHandler.cs
@@ -6,7 +6,6 @@
 using Microsoft.DevProxy.Abstractions;
 using System.CommandLine;
 using System.CommandLine.Invocation;
-using Microsoft.DevProxy.LanguageModel;
 
 namespace Microsoft.DevProxy;
 
diff --git a/dev-proxy/ProxyConfiguration.cs b/dev-proxy/ProxyConfiguration.cs
index 239f7bca..0f24997b 100755
--- a/dev-proxy/ProxyConfiguration.cs
+++ b/dev-proxy/ProxyConfiguration.cs
@@ -4,7 +4,6 @@
 using System.Runtime.Serialization;
 using System.Text.Json.Serialization;
 using Microsoft.DevProxy.Abstractions;
-using Microsoft.DevProxy.LanguageModel;
 using Microsoft.Extensions.Logging;
 
 namespace Microsoft.DevProxy;