Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gemini #105

Merged
merged 2 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions BotNet.CommandHandlers/AI/Gemini/GeminiTextPromptHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
using BotNet.Commands;
using BotNet.Commands.AI.Gemini;
using BotNet.Commands.BotUpdate.Message;
using BotNet.Commands.ChatAggregate;
using BotNet.Commands.CommandPrioritization;
using BotNet.Services.Gemini;
using BotNet.Services.Gemini.Models;
using BotNet.Services.MarkdownV2;
using BotNet.Services.RateLimit;
using Microsoft.Extensions.Logging;
using Telegram.Bot;
using Telegram.Bot.Types;
using Telegram.Bot.Types.Enums;

namespace BotNet.CommandHandlers.AI.Gemini {
public sealed class GeminiTextPromptHandler(
ITelegramBotClient telegramBotClient,
GeminiClient geminiClient,
ITelegramMessageCache telegramMessageCache,
CommandPriorityCategorizer commandPriorityCategorizer,
ILogger<GeminiTextPromptHandler> logger
) : ICommandHandler<GeminiTextPrompt> {
internal static readonly RateLimiter CHAT_RATE_LIMITER = RateLimiter.PerChat(60, TimeSpan.FromMinutes(1));

private readonly ITelegramBotClient _telegramBotClient = telegramBotClient;
private readonly GeminiClient _geminiClient = geminiClient;
private readonly ITelegramMessageCache _telegramMessageCache = telegramMessageCache;
private readonly CommandPriorityCategorizer _commandPriorityCategorizer = commandPriorityCategorizer;
private readonly ILogger<GeminiTextPromptHandler> _logger = logger;

public Task Handle(GeminiTextPrompt textPrompt, CancellationToken cancellationToken) {
if (textPrompt.Command.Chat is not HomeGroupChat) {
return _telegramBotClient.SendTextMessageAsync(
chatId: textPrompt.Command.Chat.Id,
text: MarkdownV2Sanitizer.Sanitize("Gemini tidak bisa dipakai di sini."),
parseMode: ParseMode.MarkdownV2,
replyToMessageId: textPrompt.Command.MessageId,
cancellationToken: cancellationToken
);
}

try {
CHAT_RATE_LIMITER.ValidateActionRate(
chatId: textPrompt.Command.Chat.Id,
userId: textPrompt.Command.Sender.Id
);
} catch (RateLimitExceededException exc) {
return _telegramBotClient.SendTextMessageAsync(
chatId: textPrompt.Command.Chat.Id,
text: $"<code>Anda terlalu banyak memanggil AI. Coba lagi {exc.Cooldown}.</code>",
parseMode: ParseMode.Html,
replyToMessageId: textPrompt.Command.MessageId,
cancellationToken: cancellationToken
);
}

// Fire and forget
Task.Run(async () => {
List<Content> messages = [];

// Merge adjacent messages from same role
foreach (MessageBase message in textPrompt.Thread.Reverse()) {
Content content = Content.FromText(
role: message.Sender.GeminiRole,
text: message.Text
);

if (messages.Count > 0
&& messages[^1].Role == message.Sender.GeminiRole) {
messages[^1].Add(content);
} else {
messages.Add(content);
}
}

// Trim thread longer than 10 messages
while (messages.Count > 10) {
messages.RemoveAt(0);
}

// Thread must start with user message
while (messages.Count > 0
&& messages[0].Role != "user") {
messages.RemoveAt(0);
}

messages.Add(
Content.FromText("user", textPrompt.Prompt)
);

Message responseMessage = await _telegramBotClient.SendTextMessageAsync(
chatId: textPrompt.Command.Chat.Id,
text: MarkdownV2Sanitizer.Sanitize("… ⏳"),
parseMode: ParseMode.MarkdownV2,
replyToMessageId: textPrompt.Command.MessageId
);

string response = await _geminiClient.ChatAsync(
messages: messages,
maxTokens: 512,
cancellationToken: cancellationToken
);

// Finalize message
try {
responseMessage = await telegramBotClient.EditMessageTextAsync(
chatId: textPrompt.Command.Chat.Id,
messageId: responseMessage.MessageId,
text: MarkdownV2Sanitizer.Sanitize(response),
parseMode: ParseMode.MarkdownV2,
cancellationToken: cancellationToken
);
} catch (Exception exc) {
_logger.LogError(exc, null);
throw;
}

// Track thread
_telegramMessageCache.Add(
message: AIResponseMessage.FromMessage(
message: responseMessage,
replyToMessage: textPrompt.Command,
callSign: "Gemini",
commandPriorityCategorizer: _commandPriorityCategorizer
)
);
});

return Task.CompletedTask;
}
}
}
12 changes: 12 additions & 0 deletions BotNet.CommandHandlers/BotUpdate/Message/AICallCommandHandler.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using BotNet.Commands;
using BotNet.Commands.AI.Gemini;
using BotNet.Commands.AI.OpenAI;
using BotNet.Commands.BotUpdate.Message;
using BotNet.Services.OpenAI;
Expand Down Expand Up @@ -37,6 +38,17 @@ await _commandQueue.DispatchAsync(
);
break;
}
case "Gemini" when command.ImageFileId is null && command.ReplyToMessage?.ImageFileId is null: {
await _commandQueue.DispatchAsync(
command: GeminiTextPrompt.FromAICallCommand(
aiCallCommand: command,
thread: command.ReplyToMessage is { } replyToMessage
? _telegramMessageCache.GetThread(replyToMessage)
: Enumerable.Empty<MessageBase>()
)
);
break;
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using BotNet.Commands;
using BotNet.Commands.AI.Gemini;
using BotNet.Commands.AI.OpenAI;
using BotNet.Commands.BotUpdate.Message;

Expand All @@ -25,6 +26,18 @@ await _commandQueue.DispatchAsync(
)
);
break;
case "Gemini":
await _commandQueue.DispatchAsync(
command: GeminiTextPrompt.FromAIFollowUpMessage(
aIFollowUpMessage: command,
thread: command.ReplyToMessage is null
? Enumerable.Empty<MessageBase>()
: _telegramMessageCache.GetThread(
firstMessage: command.ReplyToMessage
)
)
);
break;
}
}
}
Expand Down
77 changes: 77 additions & 0 deletions BotNet.Commands/AI/Gemini/GeminiTextPrompt.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
using BotNet.Commands.BotUpdate.Message;

namespace BotNet.Commands.AI.Gemini {
public sealed record GeminiTextPrompt : ICommand {
public string Prompt { get; }
public HumanMessageBase Command { get; }
public IEnumerable<MessageBase> Thread { get; }

private GeminiTextPrompt(
string prompt,
HumanMessageBase command,
IEnumerable<MessageBase> thread
) {
Prompt = prompt;
Command = command;
Thread = thread;
}

public static GeminiTextPrompt FromAICallCommand(AICallCommand aiCallCommand, IEnumerable<MessageBase> thread) {
// Call sign must be Gemini
if (aiCallCommand.CallSign != "Gemini") {
throw new ArgumentException("Call sign must be Gemini", nameof(aiCallCommand));
}

// Prompt must be non-empty
if (string.IsNullOrWhiteSpace(aiCallCommand.Text)) {
throw new ArgumentException("Prompt must be non-empty", nameof(aiCallCommand));
}

// Non-empty thread must begin with reply to message
if (thread.FirstOrDefault() is {
MessageId: { } firstMessageId,
Chat.Id: { } firstChatId
}) {
if (firstMessageId != aiCallCommand.ReplyToMessage?.MessageId
|| firstChatId != aiCallCommand.Chat.Id) {
throw new ArgumentException("Thread must begin with reply to message", nameof(thread));
}
}

return new(
prompt: aiCallCommand.Text,
command: aiCallCommand,
thread: thread
);
}

public static GeminiTextPrompt FromAIFollowUpMessage(AIFollowUpMessage aIFollowUpMessage, IEnumerable<MessageBase> thread) {
// Call sign must be Gemini
if (aIFollowUpMessage.CallSign != "Gemini") {
throw new ArgumentException("Call sign must be Gemini", nameof(aIFollowUpMessage));
}

// Prompt must be non-empty
if (string.IsNullOrWhiteSpace(aIFollowUpMessage.Text)) {
throw new ArgumentException("Prompt must be non-empty", nameof(aIFollowUpMessage));
}

// Non-empty thread must begin with reply to message
if (thread.FirstOrDefault() is {
MessageId: { } firstMessageId,
Chat.Id: { } firstChatId
}) {
if (firstMessageId != aIFollowUpMessage.ReplyToMessage?.MessageId
|| firstChatId != aIFollowUpMessage.Chat.Id) {
throw new ArgumentException("Thread must begin with reply to message", nameof(thread));
}
}

return new(
prompt: aIFollowUpMessage.Text,
command: aIFollowUpMessage,
thread: thread
);
}
}
}
3 changes: 3 additions & 0 deletions BotNet.Commands/SenderAggregate/Sender.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ public abstract record SenderBase(
string Name
) {
public abstract string ChatGPTRole { get; }
public abstract string GeminiRole { get; }
}

public record HumanSender(
SenderId Id,
string Name
) : SenderBase(Id, Name) {
public override string ChatGPTRole => "user";
public override string GeminiRole => "user";

public static bool TryCreate(
Telegram.Bot.Types.User user,
Expand Down Expand Up @@ -51,6 +53,7 @@ public sealed record BotSender(
string Name
) : SenderBase(Id, Name) {
public override string ChatGPTRole => "assistant";
public override string GeminiRole => "model";

public static bool TryCreate(
Telegram.Bot.Types.User user,
Expand Down
54 changes: 54 additions & 0 deletions BotNet.Services/Gemini/GeminiClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Net.Http;
using System.Net.Http.Json;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using BotNet.Services.Gemini.Models;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace BotNet.Services.Gemini {
public class GeminiClient(
HttpClient httpClient,
IOptions<GeminiOptions> geminiOptionsAccessor,
ILogger<GeminiClient> logger
) {
private const string BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent";
private readonly HttpClient _httpClient = httpClient;
private readonly string _apiKey = geminiOptionsAccessor.Value.ApiKey!;
private readonly ILogger<GeminiClient> _logger = logger;

public async Task<string> ChatAsync(IEnumerable<Content> messages, int maxTokens, CancellationToken cancellationToken) {
GeminiRequest geminiRequest = new(
Contents: messages.ToImmutableList(),
SafetySettings: null,
GenerationConfig: new(
MaxOutputTokens: maxTokens
)
);
using HttpRequestMessage request = new(HttpMethod.Post, BASE_URL + $"?key={_apiKey}") {
Headers = {
{ "Accept", "application/json" }
},
Content = JsonContent.Create(
inputValue: geminiRequest
)
};
using HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken);
string responseContent = await response.Content.ReadAsStringAsync(cancellationToken);
response.EnsureSuccessStatusCode();

GeminiResponse? geminiResponse = JsonSerializer.Deserialize<GeminiResponse>(responseContent);
if (geminiResponse == null) return "";
if (geminiResponse.Candidates == null) return "";
if (geminiResponse.Candidates.Count == 0) return "";
Content? content = geminiResponse.Candidates[0].Content;
if (content == null) return "";
if (content.Parts == null) return "";
if (content.Parts.Count == 0) return "";
return content.Parts[0].Text ?? "";
}
}
}
5 changes: 5 additions & 0 deletions BotNet.Services/Gemini/GeminiOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
namespace BotNet.Services.Gemini {
public class GeminiOptions {
public string? ApiKey { get; set; }
}
}
11 changes: 11 additions & 0 deletions BotNet.Services/Gemini/Models/Candidate.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using System.Collections.Immutable;
using System.Text.Json.Serialization;

namespace BotNet.Services.Gemini.Models {
public sealed record Candidate(
[property: JsonPropertyName("content")] Content? Content,
[property: JsonPropertyName("finishReason")] string? FinishReason,
[property: JsonPropertyName("index")] int? Index,
[property: JsonPropertyName("safetyRatings")] ImmutableList<SafetyRating>? SafetyRatings
);
}
22 changes: 22 additions & 0 deletions BotNet.Services/Gemini/Models/Content.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text.Json.Serialization;

namespace BotNet.Services.Gemini.Models {
public record Content(
[property: JsonPropertyName("role"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] string? Role,
[property: JsonPropertyName("parts")] List<Part>? Parts
) {
public static Content FromText(string role, string text) => new(
Role: role,
Parts: [
new(Text: text)
]
);

public void Add(Content content) {
if (content.Role != Role) throw new InvalidOperationException();
Parts!.AddRange(content.Parts!);
}
}
}
10 changes: 10 additions & 0 deletions BotNet.Services/Gemini/Models/GeminiRequest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using System.Collections.Immutable;
using System.Text.Json.Serialization;

namespace BotNet.Services.Gemini.Models {
public sealed record GeminiRequest(
[property: JsonPropertyName("contents")] ImmutableList<Content> Contents,
[property: JsonPropertyName("safetySettings"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] ImmutableList<SafetySettings>? SafetySettings,
[property: JsonPropertyName("generationConfig"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] GenerationConfig? GenerationConfig
);
}
Loading
Loading