Skip to content

Commit

Permalink
Extends Dev Proxy with a language model. Closes #745 (#754)
Browse files Browse the repository at this point in the history
  • Loading branch information
waldekmastykarz authored Jun 5, 2024
1 parent af65ca6 commit 9e73c68
Show file tree
Hide file tree
Showing 11 changed files with 228 additions and 24 deletions.
4 changes: 4 additions & 0 deletions dev-proxy-abstractions/ILanguageModelClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
public interface ILanguageModelClient
{
Task<string?> GenerateCompletion(string prompt);
}
1 change: 1 addition & 0 deletions dev-proxy-abstractions/PluginEvents.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public interface IProxyContext
{
IProxyConfiguration Configuration { get; }
X509Certificate2? Certificate { get; }
ILanguageModelClient LanguageModelClient { get; }
}

public class ThrottlerInfo
Expand Down
31 changes: 24 additions & 7 deletions dev-proxy-plugins/RequestLogs/OpenApiSpecGeneratorPlugin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,14 @@ public override void Register()
PluginEvents.AfterRecordingStop += AfterRecordingStop;
}

private Task AfterRecordingStop(object? sender, RecordingArgs e)
private async Task AfterRecordingStop(object? sender, RecordingArgs e)
{
Logger.LogInformation("Creating OpenAPI spec from recorded requests...");

if (!e.RequestLogs.Any())
{
Logger.LogDebug("No requests to process");
return Task.CompletedTask;
return;
}

var openApiDocs = new List<OpenApiDocument>();
Expand Down Expand Up @@ -334,7 +334,16 @@ request.Context is null ||
var pathItem = GetOpenApiPathItem(request.Context.Session);
var parametrizedPath = ParametrizePath(pathItem, request.Context.Session.HttpClient.Request.RequestUri);
var operationInfo = pathItem.Operations.First();
operationInfo.Value.OperationId = GetOperationId(operationInfo.Key.ToString(), parametrizedPath);
operationInfo.Value.OperationId = await GetOperationId(
operationInfo.Key.ToString(),
request.Context.Session.HttpClient.Request.RequestUri.GetLeftPart(UriPartial.Authority),
parametrizedPath
);
operationInfo.Value.Description = await GetOperationDescription(
operationInfo.Key.ToString(),
request.Context.Session.HttpClient.Request.RequestUri.GetLeftPart(UriPartial.Authority),
parametrizedPath
);
AddOrMergePathItem(openApiDocs, pathItem, request.Context.Session.HttpClient.Request.RequestUri, parametrizedPath);
}
catch (Exception ex)
Expand Down Expand Up @@ -370,8 +379,6 @@ request.Context is null ||
// store the generated OpenAPI specs in the global data
// for use by other plugins
e.GlobalData[GeneratedOpenApiSpecsKey] = generatedOpenApiSpecs;

return Task.CompletedTask;
}

/**
Expand Down Expand Up @@ -441,9 +448,18 @@ private string GetLastNonTokenSegment(string[] segments)
return "item";
}

private string GetOperationId(string method, string parametrizedPath)
private async Task<string> GetOperationId(string method, string serverUrl, string parametrizedPath)
{
var prompt = $"For the specified request, generate an operation ID, compatible with an OpenAPI spec. Respond with just the ID in plain-text format. For example, for request such as `GET https://api.contoso.com/books/{{books-id}}` you return `getBookById`. For a request like `GET https://api.contoso.com/books/{{books-id}}/authors` you return `getAuthorsForBookById`. Request: {method.ToUpper()} {serverUrl}{parametrizedPath}";
var id = await Context.LanguageModelClient.GenerateCompletion(prompt);
return id ?? $"{method}{parametrizedPath.Replace('/', '.')}";
}

private async Task<string> GetOperationDescription(string method, string serverUrl, string parametrizedPath)
{
return $"{method}{parametrizedPath.Replace('/', '.')}";
var prompt = $"You're an expert in OpenAPI. You help developers build great OpenAPI specs for use with LLMs. For the specified request, generate a one-sentence description. Respond with just the description. For example, for a request such as `GET https://api.contoso.com/books/{{books-id}}` you return `Get a book by ID`. Request: {method.ToUpper()} {serverUrl}{parametrizedPath}";
var description = await Context.LanguageModelClient.GenerateCompletion(prompt);
return description ?? $"{method} {parametrizedPath}";
}

/**
Expand Down Expand Up @@ -472,6 +488,7 @@ private OpenApiPathItem GetOpenApiPathItem(SessionEventArgs session)
};
var operation = new OpenApiOperation
{
// will be replaced later after the path has been parametrized
Description = $"{method} {resource}",
// will be replaced later after the path has been parametrized
OperationId = $"{method}.{resource}"
Expand Down
146 changes: 146 additions & 0 deletions dev-proxy/LanguageModel/LanguageModelClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Diagnostics;
using System.Net.Http.Json;
using Microsoft.Extensions.Logging;

namespace Microsoft.DevProxy.LanguageModel;

public class LanguageModelClient(LanguageModelConfiguration? configuration, ILogger logger) : ILanguageModelClient
{
private readonly LanguageModelConfiguration? _configuration = configuration;
private readonly ILogger _logger = logger;
private bool? _lmAvailable;
private Dictionary<string, string> _cache = new();

public async Task<string?> GenerateCompletion(string prompt)
{
using var scope = _logger.BeginScope("Language Model");

if (_configuration == null || !_configuration.Enabled)
{
// LM turned off. Nothing to do, nothing to report
return null;
}

if (!_lmAvailable.HasValue)
{
if (string.IsNullOrEmpty(_configuration.Url))
{
_logger.LogError("URL is not set. Language model will be disabled");
_lmAvailable = false;
return null;
}

if (string.IsNullOrEmpty(_configuration.Model))
{
_logger.LogError("Model is not set. Language model will be disabled");
_lmAvailable = false;
return null;
}

_logger.LogDebug("Checking availability...");
_lmAvailable = await IsLmAvailable();

// we want to log this only once
if (!_lmAvailable.Value)
{
_logger.LogError("{model} at {url} is not available", _configuration.Model, _configuration.Url);
return null;
}
}

if (!_lmAvailable.Value)
{
return null;
}

if (_configuration.CacheResponses && _cache.TryGetValue(prompt, out var cachedResponse))
{
_logger.LogDebug("Returning cached response for prompt: {prompt}", prompt);
return cachedResponse;
}

var response = await GenerateCompletionInternal(prompt);
if (response == null)
{
return null;
}
if (response.Error is not null)
{
_logger.LogError(response.Error);
return null;
}
else
{
if (_configuration.CacheResponses && response.Response is not null)
{
_cache[prompt] = response.Response;
}

return response.Response;
}
}

private async Task<LanguageModelResponse?> GenerateCompletionInternal(string prompt)
{
Debug.Assert(_configuration != null, "Configuration is null");

try
{
using var client = new HttpClient();
var url = $"{_configuration.Url}/api/generate";
_logger.LogDebug("Requesting completion. Prompt: {prompt}", prompt);

var response = await client.PostAsJsonAsync(url,
new
{
prompt,
model = _configuration.Model,
stream = false
}
);
return await response.Content.ReadFromJsonAsync<LanguageModelResponse>();
}
catch (Exception ex)
{
_logger.LogError(ex, "Failed to generate completion");
return null;
}
}

private async Task<bool> IsLmAvailable()
{
Debug.Assert(_configuration != null, "Configuration is null");

_logger.LogDebug("Checking LM availability at {url}...", _configuration.Url);

try
{
// check if lm is on
using var client = new HttpClient();
var response = await client.GetAsync(_configuration.Url);
_logger.LogDebug("Response: {response}", response.StatusCode);

if (!response.IsSuccessStatusCode)
{
return false;
}

var testCompletion = await GenerateCompletionInternal("Are you there? Reply with a yes or no.");
if (testCompletion?.Error is not null)
{
_logger.LogError("Error: {error}", testCompletion.Error);
return false;
}

return true;
}
catch (Exception ex)
{
_logger.LogError(ex, "Couldn't reach language model at {url}", _configuration.Url);
return false;
}
}
}
13 changes: 13 additions & 0 deletions dev-proxy/LanguageModel/LanguageModelConfiguration.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace Microsoft.DevProxy.LanguageModel;

public class LanguageModelConfiguration
{
public bool Enabled { get; set; } = false;
// default Ollama URL
public string? Url { get; set; } = "http://localhost:11434";
public string? Model { get; set; } = "phi3";
public bool CacheResponses { get; set; } = true;
}
10 changes: 10 additions & 0 deletions dev-proxy/LanguageModel/LanguageModelResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace Microsoft.DevProxy.LanguageModel;

public class LanguageModelResponse
{
public string? Response { get; init; }
public string? Error { get; init; }
}
31 changes: 18 additions & 13 deletions dev-proxy/Logging/ProxyConsoleFormatter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,34 +102,39 @@ private void LogMessage<TState>(in LogEntry<TState> logEntry, IExternalScopeProv
var logLevel = logEntry.LogLevel;
var message = logEntry.Formatter(logEntry.State, logEntry.Exception);

WriteMessageBoxedWithInvertedLabels(message, logLevel, textWriter);
WriteMessageBoxedWithInvertedLabels(message, logLevel, scopeProvider, textWriter);

if (logEntry.Exception is not null)
{
textWriter.Write($" Exception Details: {logEntry.Exception}");
}

if (_options.IncludeScopes && scopeProvider is not null)
{
scopeProvider.ForEachScope((scope, state) =>
{
state.Write(" => ");
state.Write(scope);
}, textWriter);
}
textWriter.WriteLine();
}

private void WriteMessageBoxedWithInvertedLabels(string? message, LogLevel logLevel, TextWriter textWriter)
private void WriteMessageBoxedWithInvertedLabels(string? message, LogLevel logLevel, IExternalScopeProvider? scopeProvider, TextWriter textWriter)
{
if (message is null)
{
return;
}

var label = GetLogLevelString(logLevel);
var (bgColor, fgColor) = GetLogLevelColor(logLevel);

if (message is not null)
textWriter.WriteColoredMessage($" {label} ", bgColor, fgColor);
textWriter.Write($"{labelSpacing}{_boxSpacing}{(logLevel == LogLevel.Debug ? $"[{DateTime.Now}] " : "")}");

if (_options.IncludeScopes && scopeProvider is not null)
{
textWriter.WriteColoredMessage($" {label} ", bgColor, fgColor);
textWriter.Write($"{labelSpacing}{_boxSpacing}{(logLevel == LogLevel.Debug ? $"[{DateTime.Now}] " : "")}{message}");
scopeProvider.ForEachScope((scope, state) =>
{
state.Write(scope);
state.Write(": ");
}, textWriter);
}

textWriter.Write(message);
}

private void WriteLogMessageBoxedWithInvertedLabels(string[] message, MessageType messageType, TextWriter textWriter, bool lastMessage = false)
Expand Down
8 changes: 6 additions & 2 deletions dev-proxy/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using Microsoft.DevProxy;
using Microsoft.DevProxy.Abstractions;
using Microsoft.DevProxy.LanguageModel;
using Microsoft.DevProxy.Logging;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Console;
Expand All @@ -20,7 +21,9 @@ ILogger BuildLogger()
options.FormatterName = "devproxy";
options.LogToStandardErrorThreshold = LogLevel.Warning;
})
.AddConsoleFormatter<ProxyConsoleFormatter, ConsoleFormatterOptions>()
.AddConsoleFormatter<ProxyConsoleFormatter, ConsoleFormatterOptions>(options => {
options.IncludeScopes = true;
})
.AddRequestLogger(pluginEvents)
.SetMinimumLevel(ProxyHost.LogLevel ?? ProxyCommandHandler.Configuration.LogLevel);
});
Expand All @@ -29,7 +32,8 @@ ILogger BuildLogger()

var logger = BuildLogger();

IProxyContext context = new ProxyContext(ProxyCommandHandler.Configuration, ProxyEngine.Certificate);
var lmClient = new LanguageModelClient(ProxyCommandHandler.Configuration.LanguageModel, logger);
IProxyContext context = new ProxyContext(ProxyCommandHandler.Configuration, ProxyEngine.Certificate, lmClient);
ProxyHost proxyHost = new();

// this is where the root command is created which contains all commands and subcommands
Expand Down
1 change: 1 addition & 0 deletions dev-proxy/ProxyCommandHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.DevProxy.Abstractions;
using System.CommandLine;
using System.CommandLine.Invocation;
using Microsoft.DevProxy.LanguageModel;

namespace Microsoft.DevProxy;

Expand Down
3 changes: 2 additions & 1 deletion dev-proxy/ProxyConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Runtime.Serialization;
using System.Text.Json.Serialization;
using Microsoft.DevProxy.Abstractions;
using Microsoft.DevProxy.LanguageModel;
using Microsoft.Extensions.Logging;

namespace Microsoft.DevProxy;
Expand Down Expand Up @@ -36,6 +37,6 @@ public class ProxyConfiguration : IProxyConfiguration
public string ConfigFile { get; set; } = "devproxyrc.json";
[JsonConverter(typeof(JsonStringEnumConverter))]
public ReleaseType NewVersionNotification { get; set; } = ReleaseType.Stable;
public LanguageModelConfiguration? LanguageModel { get; set; }
public MockRequestHeader[]? FilterByHeaders { get; set; }
}

4 changes: 3 additions & 1 deletion dev-proxy/ProxyContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ internal class ProxyContext : IProxyContext
{
public IProxyConfiguration Configuration { get; }
public X509Certificate2? Certificate { get; }
public ILanguageModelClient LanguageModelClient { get; }

public ProxyContext(IProxyConfiguration configuration, X509Certificate2? certificate)
public ProxyContext(IProxyConfiguration configuration, X509Certificate2? certificate, ILanguageModelClient languageModelClient)
{
Configuration = configuration ?? throw new ArgumentNullException(nameof(configuration));
Certificate = certificate;
LanguageModelClient = languageModelClient;
}
}

0 comments on commit 9e73c68

Please sign in to comment.