Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Stability SDXL 1.0 for image generation #79

Merged
merged 2 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 60 additions & 43 deletions BotNet.Services/BotCommands/Art.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System.Threading;
using System.Threading.Tasks;
using BotNet.Services.RateLimit;
using BotNet.Services.Stability;
using BotNet.Services.ThisXDoesNotExist;
using Microsoft.Extensions.DependencyInjection;
using Telegram.Bot;
Expand All @@ -24,13 +23,31 @@ public static async Task GetRandomArtAsync(ITelegramBotClient botClient, IServic
try {
GENERATED_ART_RATE_LIMITER.ValidateActionRate(message.Chat.Id, message.From!.Id);

Message busyMessage = await botClient.SendTextMessageAsync(
chatId: message.Chat.Id,
text: "Generating image… ⏳",
parseMode: ParseMode.Markdown,
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken
);

try {
byte[] image = await serviceProvider.GetRequiredService<StabilityClient>().GenerateImageAsync(commandArgument, CancellationToken.None);
byte[] image = await serviceProvider.GetRequiredService<Stability.Skills.ImageGenerationBot>().GenerateImageAsync(commandArgument, CancellationToken.None);
using MemoryStream imageStream = new(image);

try {
await botClient.DeleteMessageAsync(
chatId: busyMessage.Chat.Id,
messageId: busyMessage.MessageId,
cancellationToken: cancellationToken
);
} catch (OperationCanceledException) {
throw;
}

await botClient.SendPhotoAsync(
chatId: message.Chat.Id,
photo: new InputFileStream(imageStream, "art.jpg"),
photo: new InputFileStream(imageStream, "art.png"),
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken);
} catch {
Expand Down Expand Up @@ -72,48 +89,48 @@ await botClient.SendTextMessageAsync(
}
}

public static async Task ModifyArtAsync(ITelegramBotClient botClient, IServiceProvider serviceProvider, Message message, string textPrompt, CancellationToken cancellationToken) {
if (message.ReplyToMessage is { } replyToMessage) {
using MemoryStream originalImageStream = new();
Telegram.Bot.Types.File fileInfo = message.ReplyToMessage.Photo?.Length > 0
? await botClient.GetInfoAndDownloadFileAsync(
fileId: message.ReplyToMessage.Photo.OrderByDescending(photoSize => photoSize.Width).First().FileId,
destination: originalImageStream,
cancellationToken: cancellationToken)
: await botClient.GetInfoAndDownloadFileAsync(
fileId: message.ReplyToMessage.Sticker!.FileId,
destination: originalImageStream,
cancellationToken: cancellationToken);
//public static async Task ModifyArtAsync(ITelegramBotClient botClient, IServiceProvider serviceProvider, Message message, string textPrompt, CancellationToken cancellationToken) {
// if (message.ReplyToMessage is { } replyToMessage) {
// using MemoryStream originalImageStream = new();
// Telegram.Bot.Types.File fileInfo = message.ReplyToMessage.Photo?.Length > 0
// ? await botClient.GetInfoAndDownloadFileAsync(
// fileId: message.ReplyToMessage.Photo.OrderByDescending(photoSize => photoSize.Width).First().FileId,
// destination: originalImageStream,
// cancellationToken: cancellationToken)
// : await botClient.GetInfoAndDownloadFileAsync(
// fileId: message.ReplyToMessage.Sticker!.FileId,
// destination: originalImageStream,
// cancellationToken: cancellationToken);

try {
MODIFY_ART_RATE_LIMITER.ValidateActionRate(message.Chat.Id, message.From!.Id);
// try {
// MODIFY_ART_RATE_LIMITER.ValidateActionRate(message.Chat.Id, message.From!.Id);

try {
byte[] image = await serviceProvider.GetRequiredService<StabilityClient>().ModifyImageAsync(originalImageStream.ToArray(), textPrompt, CancellationToken.None);
using MemoryStream imageStream = new(image);
// try {
// byte[] image = await serviceProvider.GetRequiredService<StabilityClient>().ModifyImageAsync(originalImageStream.ToArray(), textPrompt, CancellationToken.None);
// using MemoryStream imageStream = new(image);

await botClient.SendPhotoAsync(
chatId: message.Chat.Id,
photo: new InputFileStream(imageStream, "art.jpg"),
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken);
} catch {
await botClient.SendTextMessageAsync(
chatId: message.Chat.Id,
text: "<code>Could not generate art</code>",
parseMode: ParseMode.Html,
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken);
}
} catch (RateLimitExceededException exc) when (exc is { Cooldown: var cooldown }) {
await botClient.SendTextMessageAsync(
chatId: message.Chat.Id,
text: $"Anda belum mendapat giliran. Coba lagi {cooldown}.",
parseMode: ParseMode.Html,
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken);
}
}
}
// await botClient.SendPhotoAsync(
// chatId: message.Chat.Id,
// photo: new InputFileStream(imageStream, "art.jpg"),
// replyToMessageId: message.MessageId,
// cancellationToken: cancellationToken);
// } catch {
// await botClient.SendTextMessageAsync(
// chatId: message.Chat.Id,
// text: "<code>Could not generate art</code>",
// parseMode: ParseMode.Html,
// replyToMessageId: message.MessageId,
// cancellationToken: cancellationToken);
// }
// } catch (RateLimitExceededException exc) when (exc is { Cooldown: var cooldown }) {
// await botClient.SendTextMessageAsync(
// chatId: message.Chat.Id,
// text: $"Anda belum mendapat giliran. Coba lagi {cooldown}.",
// parseMode: ParseMode.Html,
// replyToMessageId: message.MessageId,
// cancellationToken: cancellationToken);
// }
// }
//}
}
}
82 changes: 42 additions & 40 deletions BotNet.Services/BotCommands/OpenAI.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using BotNet.Services.OpenAI.Models;
using BotNet.Services.OpenAI.Skills;
using BotNet.Services.RateLimit;
using BotNet.Services.Stability.Skills;
using Microsoft.Extensions.DependencyInjection;
using RG.Ninja;
using SkiaSharp;
Expand Down Expand Up @@ -766,9 +767,8 @@ await botClient.SendTextMessageAsync(
}
}

private static readonly RateLimiter IMAGE_GENERATION_PER_USER_RATE_LIMITER = RateLimiter.PerUser(1, TimeSpan.FromMinutes(10));
private static readonly RateLimiter IMAGE_GENERATION_PER_CHAT_RATE_LIMITER = RateLimiter.PerChat(2, TimeSpan.FromMinutes(5));
private static readonly RateLimiter IMAGE_GENERATION_GLOBAL_RATE_LIMITER = RateLimiter.PerChat(1, TimeSpan.FromMinutes(1));
private static readonly RateLimiter IMAGE_GENERATION_PER_USER_RATE_LIMITER = RateLimiter.PerUser(1, TimeSpan.FromMinutes(5));
private static readonly RateLimiter IMAGE_GENERATION_PER_CHAT_RATE_LIMITER = RateLimiter.PerChat(2, TimeSpan.FromMinutes(3));
public static async Task StreamChatWithFriendlyBotAsync(
ITelegramBotClient botClient,
IServiceProvider serviceProvider,
Expand Down Expand Up @@ -831,46 +831,48 @@ await serviceProvider.GetRequiredService<FriendlyBot>().StreamChatAsync(
replyToMessageId: message.MessageId
);
break;
case ChatIntent.ImageGeneration:
IMAGE_GENERATION_PER_USER_RATE_LIMITER.ValidateActionRate(
chatId: message.Chat.Id,
userId: message.From.Id
);
IMAGE_GENERATION_PER_CHAT_RATE_LIMITER.ValidateActionRate(
chatId: message.Chat.Id,
userId: message.From.Id
);
IMAGE_GENERATION_GLOBAL_RATE_LIMITER.ValidateActionRate(
chatId: 0,
userId: 0
);
Message busyMessage = await botClient.SendTextMessageAsync(
chatId: message.Chat.Id,
text: "Generating image… ⏳",
parseMode: ParseMode.Markdown,
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken
);
Uri generatedImageUrl = await serviceProvider.GetRequiredService<ImageGenerationBot>().GenerateImageAsync(
prompt: message.Text!,
cancellationToken: cancellationToken
);
try {
await botClient.DeleteMessageAsync(
chatId: busyMessage.Chat.Id,
messageId: busyMessage.MessageId,
case ChatIntent.ImageGeneration: {
IMAGE_GENERATION_PER_USER_RATE_LIMITER.ValidateActionRate(
chatId: message.Chat.Id,
userId: message.From.Id
);
IMAGE_GENERATION_PER_CHAT_RATE_LIMITER.ValidateActionRate(
chatId: message.Chat.Id,
userId: message.From.Id
);
Message busyMessage = await botClient.SendTextMessageAsync(
chatId: message.Chat.Id,
text: "Generating image… ⏳",
parseMode: ParseMode.Markdown,
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken
);
//Uri generatedImageUrl = await serviceProvider.GetRequiredService<ImageGenerationBot>().GenerateImageAsync(
// prompt: message.Text!,
// cancellationToken: cancellationToken
//);
byte[] generatedImage = await serviceProvider.GetRequiredService<Stability.Skills.ImageGenerationBot>().GenerateImageAsync(
prompt: message.Text!,
cancellationToken: cancellationToken
);
using MemoryStream generatedImageStream = new(generatedImage);
try {
await botClient.DeleteMessageAsync(
chatId: busyMessage.Chat.Id,
messageId: busyMessage.MessageId,
cancellationToken: cancellationToken
);
} catch (OperationCanceledException) {
throw;
}
await botClient.SendPhotoAsync(
chatId: message.Chat.Id,
photo: new InputFileStream(generatedImageStream, "art.png"),
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken
);
} catch (OperationCanceledException) {
throw;
break;
}
await botClient.SendPhotoAsync(
chatId: message.Chat.Id,
photo: new InputFileUrl(generatedImageUrl),
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken
);
break;
}
}
} catch (RateLimitExceededException exc) when (exc is { Cooldown: var cooldown }) {
Expand Down
5 changes: 0 additions & 5 deletions BotNet.Services/BotNet.Services.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
<ItemGroup>
<None Remove="CopyPasta\Pasta.json" />
<None Remove="Meme\Images\ramad.jpg" />
<None Remove="Stability\generation.proto" />
<None Remove="FancyText\CharMaps\Bold.json" />
<None Remove="FancyText\CharMaps\BoldItalic.json" />
<None Remove="FancyText\CharMaps\Cursive.json" />
Expand Down Expand Up @@ -100,8 +99,4 @@
<ProjectReference Include="..\pehape\csharp\Pehape\Pehape.csproj" />
</ItemGroup>

<ItemGroup>
<Protobuf Include="Stability\generation.proto" GrpcServices="Client" />
</ItemGroup>

</Project>
13 changes: 13 additions & 0 deletions BotNet.Services/Stability/Models/TextToImageResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using System.Collections.Generic;

namespace BotNet.Services.Stability.Models {
internal sealed record TextToImageResponse(
List<Artifact> Artifacts
);

internal sealed record Artifact(
string Base64,
string FinishReason,
int Seed
);
}
4 changes: 3 additions & 1 deletion BotNet.Services/Stability/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using Microsoft.Extensions.DependencyInjection;
using BotNet.Services.Stability.Skills;
using Microsoft.Extensions.DependencyInjection;

namespace BotNet.Services.Stability {
public static class ServiceCollectionExtensions {
public static IServiceCollection AddStabilityClient(this IServiceCollection services) {
services.AddSingleton<StabilityClient>();
services.AddSingleton<ImageGenerationBot>();
return services;
}
}
Expand Down
21 changes: 21 additions & 0 deletions BotNet.Services/Stability/Skills/ImageGenerationBot.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using System.Threading;
using System.Threading.Tasks;

namespace BotNet.Services.Stability.Skills {
public sealed class ImageGenerationBot(
StabilityClient stabilityClient
) {
private readonly StabilityClient _stabilityClient = stabilityClient;

public async Task<byte[]> GenerateImageAsync(
string prompt,
CancellationToken cancellationToken
) {
return await _stabilityClient.GenerateImageAsync(
engine: "stable-diffusion-xl-1024-v1-0",
promptText: prompt,
cancellationToken: cancellationToken
);
}
}
}
Loading
Loading