From 5bdf7505f6d15b2ef7fdc41792089ebae9860f8c Mon Sep 17 00:00:00 2001 From: GR Date: Mon, 27 Jan 2025 23:00:47 +0800 Subject: [PATCH] feat: add DeepSeek model client --- models/spring-ai-deepseek/README.md | 1 + models/spring-ai-deepseek/pom.xml | 63 + .../ai/deepseek/DeepSeekChatModel.java | 550 ++++++++ .../ai/deepseek/DeepSeekChatOptions.java | 554 +++++++++ .../ai/deepseek/aot/DeepSeekRuntimeHints.java | 42 + .../ai/deepseek/api/DeepSeekApi.java | 1101 +++++++++++++++++ .../DeepSeekStreamFunctionCallingHelper.java | 176 +++ .../ai/deepseek/api/ResponseFormat.java | 126 ++ .../api/common/DeepSeekConstants.java | 37 + .../ai/deepseek/metadata/DeepSeekUsage.java | 62 + .../resources/META-INF/spring/aot.factories | 2 + .../DeepSeekChatCompletionRequestTests.java | 53 + .../ai/deepseek/DeepSeekRetryTests.java | 148 +++ .../deepseek/DeepSeekTestConfiguration.java | 48 + .../aot/DeepSeekRuntimeHintsTests.java | 46 + .../ai/deepseek/api/DeepSeekApiIT.java | 57 + .../ai/deepseek/api/MockWeatherService.java | 95 ++ .../ai/deepseek/chat/ActorsFilms.java | 53 + .../DeepSeekChatModelFunctionCallingIT.java | 185 +++ .../ai/deepseek/chat/DeepSeekChatModelIT.java | 192 +++ .../chat/DeepSeekChatModelObservationIT.java | 180 +++ .../test/resources/prompts/system-message.st | 4 + pom.xml | 2 + spring-ai-bom/pom.xml | 12 + .../observation/conventions/AiProvider.java | 5 + .../src/main/antora/modules/ROOT/nav.adoc | 1 + .../ROOT/pages/api/chat/deepseek-chat.adoc | 251 ++++ spring-ai-spring-boot-autoconfigure/pom.xml | 8 + .../deepseek/DeepSeekAutoConfiguration.java | 103 ++ .../deepseek/DeepSeekChatProperties.java | 89 ++ .../deepseek/DeepSeekCommonProperties.java | 37 + .../deepseek/DeepSeekParentProperties.java | 46 + ...ot.autoconfigure.AutoConfiguration.imports | 1 + .../deepseek/DeepSeekAutoConfigurationIT.java | 76 ++ .../deepseek/DeepSeekPropertiesTests.java | 162 +++ .../tool/DeepSeekFunctionCallbackIT.java | 124 ++ .../tool/FunctionCallbackInPromptIT.java | 117 ++ ...nctionCallbackWithPlainFunctionBeanIT.java | 175 +++ .../deepseek/tool/MockWeatherService.java | 95 ++ .../spring-ai-starter-deepseek/pom.xml | 58 + 40 files changed, 5137 insertions(+) create mode 100644 models/spring-ai-deepseek/README.md create mode 100644 models/spring-ai-deepseek/pom.xml create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/aot/DeepSeekRuntimeHints.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekStreamFunctionCallingHelper.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/ResponseFormat.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/common/DeepSeekConstants.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/metadata/DeepSeekUsage.java create mode 100644 models/spring-ai-deepseek/src/main/resources/META-INF/spring/aot.factories create mode 100644 models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekChatCompletionRequestTests.java create mode 100644 
models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekRetryTests.java create mode 100644 models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekTestConfiguration.java create mode 100644 models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/aot/DeepSeekRuntimeHintsTests.java create mode 100644 models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/api/DeepSeekApiIT.java create mode 100644 models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/api/MockWeatherService.java create mode 100644 models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/ActorsFilms.java create mode 100644 models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelFunctionCallingIT.java create mode 100644 models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelIT.java create mode 100644 models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelObservationIT.java create mode 100644 models/spring-ai-deepseek/src/test/resources/prompts/system-message.st create mode 100644 spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/deepseek-chat.adoc create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekAutoConfiguration.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekChatProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekCommonProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekParentProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekAutoConfigurationIT.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekPropertiesTests.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/tool/DeepSeekFunctionCallbackIT.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/tool/FunctionCallbackInPromptIT.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/tool/FunctionCallbackWithPlainFunctionBeanIT.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/tool/MockWeatherService.java create mode 100644 spring-ai-spring-boot-starters/spring-ai-starter-deepseek/pom.xml diff --git a/models/spring-ai-deepseek/README.md b/models/spring-ai-deepseek/README.md new file mode 100644 index 00000000000..2a084525110 --- /dev/null +++ b/models/spring-ai-deepseek/README.md @@ -0,0 +1 @@ +[DeepSeek Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/deepseek-chat.html) \ No newline at end of file diff --git a/models/spring-ai-deepseek/pom.xml b/models/spring-ai-deepseek/pom.xml new file mode 100644 index 00000000000..932dc1eb389 --- /dev/null +++ b/models/spring-ai-deepseek/pom.xml @@ -0,0 +1,63 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-deepseek + jar + Spring AI DeepSeek + DeepSeek support + 
https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + + + + org.springframework + spring-context-support + + + + org.springframework.boot + spring-boot-starter-logging + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + io.micrometer + micrometer-observation-test + test + + + + diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java new file mode 100644 index 00000000000..11fd195846f --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java @@ -0,0 +1,550 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.deepseek; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.*; +import org.springframework.ai.chat.model.*; +import org.springframework.ai.chat.observation.ChatModelObservationContext; +import org.springframework.ai.chat.observation.ChatModelObservationConvention; +import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletion; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletion.Choice; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.ChatCompletionFunction; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.MediaContent; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.ToolCall; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionRequest; +import org.springframework.ai.deepseek.api.common.DeepSeekConstants; +import org.springframework.ai.deepseek.metadata.DeepSeekUsage; +import org.springframework.ai.model.Media; +import 
org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.*; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; + +/** + * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal DeepSeek} + * backed by {@link DeepSeekApi}. + * + * @author Geng Rong + */ +public class DeepSeekChatModel extends AbstractToolCallSupport implements ChatModel { + + private static final Logger logger = LoggerFactory.getLogger(DeepSeekChatModel.class); + + private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + + /** + * The default options used for the chat completion requests. + */ + private final DeepSeekChatOptions defaultOptions; + + /** + * The retry template used to retry the DeepSeek API calls. + */ + public final RetryTemplate retryTemplate; + + /** + * Low-level access to the DeepSeek API. + */ + private final DeepSeekApi deepSeekApi; + + /** + * Observation registry used for instrumentation. + */ + private final ObservationRegistry observationRegistry; + + /** + * Conventions to use for generating observations. + */ + private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + + /** + * Creates an instance of the DeepSeekChatModel. + * @param deepSeekApi The DeepSeekApi instance to be used for interacting with the + * DeepSeek Chat API. + * @throws IllegalArgumentException if deepSeekApi is null + */ + public DeepSeekChatModel(DeepSeekApi deepSeekApi) { + this(deepSeekApi, DeepSeekChatOptions.builder().model(DeepSeekApi.DEFAULT_CHAT_MODEL).temperature(0.7).build()); + } + + /** + * Initializes an instance of the DeepSeekChatModel. + * @param deepSeekApi The DeepSeekApi instance to be used for interacting with the + * DeepSeek Chat API. + * @param options The DeepSeekChatOptions to configure the chat model. + */ + public DeepSeekChatModel(DeepSeekApi deepSeekApi, DeepSeekChatOptions options) { + this(deepSeekApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + /** + * Initializes a new instance of the DeepSeekChatModel. + * @param deepSeekApi The DeepSeekApi instance to be used for interacting with the + * DeepSeek Chat API. + * @param options The DeepSeekChatOptions to configure the chat model. + * @param functionCallbackResolver The function callback resolver. + * @param retryTemplate The retry template. + */ + public DeepSeekChatModel(DeepSeekApi deepSeekApi, DeepSeekChatOptions options, + FunctionCallbackResolver functionCallbackResolver, RetryTemplate retryTemplate) { + this(deepSeekApi, options, functionCallbackResolver, List.of(), retryTemplate); + } + + /** + * Initializes a new instance of the DeepSeekChatModel. + * @param deepSeekApi The DeepSeekApi instance to be used for interacting with the + * DeepSeek Chat API. + * @param options The DeepSeekChatOptions to configure the chat model. + * @param functionCallbackResolver The function callback resolver. + * @param toolFunctionCallbacks The tool function callbacks. + * @param retryTemplate The retry template. 
+ */ + public DeepSeekChatModel(DeepSeekApi deepSeekApi, DeepSeekChatOptions options, + FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks, + RetryTemplate retryTemplate) { + this(deepSeekApi, options, functionCallbackResolver, toolFunctionCallbacks, retryTemplate, + ObservationRegistry.NOOP); + } + + /** + * Initializes a new instance of the DeepSeekChatModel. + * @param deepSeekApi The DeepSeekApi instance to be used for interacting with the + * DeepSeek Chat API. + * @param options The DeepSeekChatOptions to configure the chat model. + * @param functionCallbackResolver The function callback resolver. + * @param toolFunctionCallbacks The tool function callbacks. + * @param retryTemplate The retry template. + * @param observationRegistry The ObservationRegistry used for instrumentation. + */ + public DeepSeekChatModel(DeepSeekApi deepSeekApi, DeepSeekChatOptions options, + FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks, + RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { + + super(functionCallbackResolver, options, toolFunctionCallbacks); + + Assert.notNull(deepSeekApi, "DeepSeekApi must not be null"); + Assert.notNull(options, "Options must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + Assert.isTrue(CollectionUtils.isEmpty(options.getFunctionCallbacks()), + "The default function callbacks must be set via the toolFunctionCallbacks constructor parameter"); + Assert.notNull(observationRegistry, "ObservationRegistry must not be null"); + + this.deepSeekApi = deepSeekApi; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + this.observationRegistry = observationRegistry; + } + + @Override + public ChatResponse call(Prompt prompt) { + return this.internalCall(prompt, null); + } + + public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + + ChatCompletionRequest request = createRequest(prompt, false); + + ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + .prompt(prompt) + .provider(DeepSeekConstants.PROVIDER_NAME) + .requestOptions(buildRequestOptions(request)) + .build(); + + ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + + ResponseEntity completionEntity = this.retryTemplate + .execute(ctx -> this.deepSeekApi.chatCompletionEntity(request)); + + var chatCompletion = completionEntity.getBody(); + + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } + + List choices = chatCompletion.choices(); + if (choices == null) { + logger.warn("No choices returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } + + List generations = choices.stream().map(choice -> { + // @formatter:off + Map metadata = Map.of( + "id", chatCompletion.id() != null ? chatCompletion.id() : "", + "role", choice.message().role() != null ? choice.message().role().name() : "", + "index", choice.index(), + "finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""); + // @formatter:on + return buildGeneration(choice, metadata); + }).toList(); + + // Current usage + DeepSeekApi.Usage usage = completionEntity.getBody().usage(); + Usage currentChatResponseUsage = usage != null ? 
DeepSeekUsage.from(usage) : new EmptyUsage(); + Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); + ChatResponse chatResponse = new ChatResponse(generations, + from(completionEntity.getBody(), accumulatedUsage)); + + observationContext.setResponse(chatResponse); + + return chatResponse; + + }); + + if (!isProxyToolCalls(prompt, this.defaultOptions) + && isToolCall(response, Set.of(DeepSeekApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + DeepSeekApi.ChatCompletionFinishReason.STOP.name()))) { + var toolCallConversation = handleToolCalls(prompt, response); + // Recursively call the call method with the tool call message + // conversation that contains the call responses. + return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response); + } + + return response; + } + + @Override + public Flux stream(Prompt prompt) { + return internalStream(prompt, null); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + return Flux.deferContextual(contextView -> { + ChatCompletionRequest request = createRequest(prompt, true); + + Flux completionChunks = this.deepSeekApi.chatCompletionStream(request); + + // For chunked responses, only the first chunk contains the choice role. + // The rest of the chunks with same ID share the same role. + ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); + + final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + .prompt(prompt) + .provider(DeepSeekConstants.PROVIDER_NAME) + .requestOptions(buildRequestOptions(request)) + .build(); + + Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry); + + observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); + + Flux chatResponse = completionChunks.map(this::chunkToChatCompletion) + .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> { + try { + String id = chatCompletion2.id(); + + List generations = chatCompletion2.choices().stream().map(choice -> { + if (choice.message().role() != null) { + roleMap.putIfAbsent(id, choice.message().role().name()); + } + + // @formatter:off + Map metadata = Map.of( + "id", chatCompletion2.id(), + "role", roleMap.getOrDefault(id, ""), + "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "" + ); + // @formatter:on + return buildGeneration(choice, metadata); + }).toList(); + DeepSeekApi.Usage usage = chatCompletion2.usage(); + Usage currentUsage = (usage != null) ? DeepSeekUsage.from(usage) : new EmptyUsage(); + Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse); + + return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage)); + } + catch (Exception e) { + logger.error("Error processing chat completion", e); + return new ChatResponse(List.of()); + } + + })); + + // @formatter:off + Flux flux = chatResponse.flatMap(response -> { + + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(DeepSeekApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + DeepSeekApi.ChatCompletionFinishReason.STOP.name()))) { + var toolCallConversation = handleToolCalls(prompt, response); + // Recursively call the stream method with the tool call message + // conversation that contains the call responses. 
+ return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response); + } + else { + return Flux.just(response); + } + }) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + // @formatter:on + + return new MessageAggregator().aggregate(flux, observationContext::setResponse); + + }); + } + + private Generation buildGeneration(Choice choice, Map metadata) { + List toolCalls = choice.message().toolCalls() == null ? List.of() + : choice.message() + .toolCalls() + .stream() + .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function", + toolCall.function().name(), toolCall.function().arguments())) + .toList(); + + String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); + var generationMetadataBuilder = ChatGenerationMetadata.builder().finishReason(finishReason); + + String textContent = choice.message().content(); + + AssistantMessage assistantMessage = new AssistantMessage(textContent, metadata, toolCalls); + return new Generation(assistantMessage, generationMetadataBuilder.build()); + } + + private ChatResponseMetadata from(DeepSeekApi.ChatCompletion result, Usage usage) { + Assert.notNull(result, "DeepSeek ChatCompletionResult must not be null"); + var builder = ChatResponseMetadata.builder() + .id(result.id() != null ? result.id() : "") + .usage(usage) + .model(result.model() != null ? result.model() : "") + .keyValue("created", result.created() != null ? result.created() : 0L) + .keyValue("system-fingerprint", result.systemFingerprint() != null ? result.systemFingerprint() : ""); + return builder.build(); + } + + private ChatResponseMetadata from(ChatResponseMetadata chatResponseMetadata, Usage usage) { + Assert.notNull(chatResponseMetadata, "DeepSeek ChatResponseMetadata must not be null"); + var builder = ChatResponseMetadata.builder() + .id(chatResponseMetadata.getId() != null ? chatResponseMetadata.getId() : "") + .usage(usage) + .model(chatResponseMetadata.getModel() != null ? chatResponseMetadata.getModel() : ""); + return builder.build(); + } + + /** + * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null. + * @param chunk the ChatCompletionChunk to convert + * @return the ChatCompletion + */ + private DeepSeekApi.ChatCompletion chunkToChatCompletion(DeepSeekApi.ChatCompletionChunk chunk) { + List choices = chunk.choices() + .stream() + .map(chunkChoice -> new Choice(chunkChoice.finishReason(), chunkChoice.index(), chunkChoice.delta(), + chunkChoice.logprobs())) + .toList(); + + return new DeepSeekApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), chunk.serviceTier(), + chunk.systemFingerprint(), chunk.usage()); + } + + /** + * Accessible for testing. 
+ */ + ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + + List chatCompletionMessages = prompt.getInstructions().stream().map(message -> { + if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) { + Object content = message.getText(); + if (message instanceof UserMessage userMessage) { + if (!CollectionUtils.isEmpty(userMessage.getMedia())) { + List contentList = new ArrayList<>(List.of(new MediaContent(message.getText()))); + + contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList()); + + content = contentList; + } + } + + return List.of(new ChatCompletionMessage(content, + ChatCompletionMessage.Role.valueOf(message.getMessageType().name()))); + } + else if (message.getMessageType() == MessageType.ASSISTANT) { + var assistantMessage = (AssistantMessage) message; + List toolCalls = null; + if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { + toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> { + var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments()); + return new ToolCall(toolCall.id(), toolCall.type(), function); + }).toList(); + } + return List.of(new ChatCompletionMessage(assistantMessage.getText(), + ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls)); + } + else if (message.getMessageType() == MessageType.TOOL) { + ToolResponseMessage toolMessage = (ToolResponseMessage) message; + + toolMessage.getResponses() + .forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id")); + return toolMessage.getResponses() + .stream() + .map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(), + tr.id(), null)) + .toList(); + } + else { + throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType()); + } + }).flatMap(List::stream).toList(); + + ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream); + + Set enabledToolsToUse = new HashSet<>(); + + if (prompt.getOptions() != null) { + DeepSeekChatOptions updatedRuntimeOptions = null; + + if (prompt.getOptions() instanceof FunctionCallingOptions) { + updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(((FunctionCallingOptions) prompt.getOptions()), + FunctionCallingOptions.class, DeepSeekChatOptions.class); + } + else { + updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, + DeepSeekChatOptions.class); + } + + enabledToolsToUse.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions)); + + request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class); + } + + if (!CollectionUtils.isEmpty(this.defaultOptions.getFunctions())) { + enabledToolsToUse.addAll(this.defaultOptions.getFunctions()); + } + + request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); + + // Add the enabled functions definitions to the request's tools parameter. 
+ if (!CollectionUtils.isEmpty(enabledToolsToUse)) { + + request = ModelOptionsUtils.merge( + DeepSeekChatOptions.builder().tools(this.getFunctionTools(enabledToolsToUse)).build(), request, + ChatCompletionRequest.class); + } + + return request; + } + + private MediaContent mapToMediaContent(Media media) { + var mimeType = media.getMimeType(); + if (MimeTypeUtils.parseMimeType("audio/mp3").equals(mimeType)) { + return new MediaContent( + new MediaContent.InputAudio(fromAudioData(media.getData()), MediaContent.InputAudio.Format.MP3)); + } + if (MimeTypeUtils.parseMimeType("audio/wav").equals(mimeType)) { + return new MediaContent( + new MediaContent.InputAudio(fromAudioData(media.getData()), MediaContent.InputAudio.Format.WAV)); + } + else { + return new MediaContent( + new MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), media.getData()))); + } + } + + private String fromAudioData(Object audioData) { + if (audioData instanceof byte[] bytes) { + return Base64.getEncoder().encodeToString(bytes); + } + throw new IllegalArgumentException("Unsupported audio data type: " + audioData.getClass().getSimpleName()); + } + + private String fromMediaData(MimeType mimeType, Object mediaContentData) { + if (mediaContentData instanceof byte[] bytes) { + // Assume the bytes are an image. So, convert the bytes to a base64 encoded + // following the prefix pattern. + return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes)); + } + else if (mediaContentData instanceof String text) { + // Assume the text is a URLs or a base64 encoded image prefixed by the user. + return text; + } + else { + throw new IllegalArgumentException( + "Unsupported media data type: " + mediaContentData.getClass().getSimpleName()); + } + } + + private List getFunctionTools(Set functionNames) { + return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> { + var function = new DeepSeekApi.FunctionTool.Function(functionCallback.getDescription(), + functionCallback.getName(), functionCallback.getInputTypeSchema()); + return new DeepSeekApi.FunctionTool(function); + }).toList(); + } + + private ChatOptions buildRequestOptions(DeepSeekApi.ChatCompletionRequest request) { + return ChatOptions.builder() + .model(request.model()) + .frequencyPenalty(request.frequencyPenalty()) + .maxTokens(request.maxTokens()) + .presencePenalty(request.presencePenalty()) + .stopSequences(request.stop()) + .temperature(request.temperature()) + .topP(request.topP()) + .build(); + } + + @Override + public ChatOptions getDefaultOptions() { + return DeepSeekChatOptions.fromOptions(this.defaultOptions); + } + + @Override + public String toString() { + return "DeepSeekChatModel [defaultOptions=" + this.defaultOptions + "]"; + } + + /** + * Use the provided convention for reporting observation data + * @param observationConvention The provided convention + */ + public void setObservationConvention(ChatModelObservationConvention observationConvention) { + Assert.notNull(observationConvention, "observationConvention cannot be null"); + this.observationConvention = observationConvention; + } + +} diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java new file mode 100644 index 00000000000..6f6c96730a7 --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java @@ -0,0 +1,554 @@ 
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.deepseek;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonInclude.Include;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import org.springframework.ai.deepseek.api.DeepSeekApi;
+import org.springframework.ai.deepseek.api.ResponseFormat;
+import org.springframework.ai.model.function.FunctionCallback;
+import org.springframework.ai.model.function.FunctionCallingOptions;
+import org.springframework.util.Assert;
+
+import java.util.*;
+
+/**
+ * Chat completion options for the DeepSeek chat API. See the DeepSeek chat completion
+ * documentation: https://platform.deepseek.com/api-docs/api/create-chat-completion
+ *
+ * @author Geng Rong
+ */
+@JsonInclude(Include.NON_NULL)
+public class DeepSeekChatOptions implements FunctionCallingOptions {
+
+	// @formatter:off
+	/**
+	 * ID of the model to use. You can use either deepseek-coder or deepseek-chat.
+	 */
+	private @JsonProperty("model") String model;
+	/**
+	 * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing
+	 * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
+	 */
+	private @JsonProperty("frequency_penalty") Double frequencyPenalty;
+	/**
+	 * The maximum number of tokens that can be generated in the chat completion.
+	 * The total length of input tokens and generated tokens is limited by the model's context length.
+	 */
+	private @JsonProperty("max_tokens") Integer maxTokens;
+	/**
+	 * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they
+	 * appear in the text so far, increasing the model's likelihood to talk about new topics.
+	 */
+	private @JsonProperty("presence_penalty") Double presencePenalty;
+	/**
+	 * An object specifying the format that the model must output. Setting to { "type":
+	 * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.
+	 */
+	private @JsonProperty("response_format") ResponseFormat responseFormat;
+	/**
+	 * A string or a list containing up to 4 strings; upon encountering these words, the API will cease generating more tokens.
+	 */
+	private @JsonProperty("stop") List stop;
+	/**
+	 * What sampling temperature to use, between 0 and 2.
+	 * Higher values like 0.8 will make the output more random,
+	 * while lower values like 0.2 will make it more focused and deterministic.
+	 * We generally recommend altering this or top_p but not both.
+	 */
+	private @JsonProperty("temperature") Double temperature;
+	/**
+	 * An alternative to sampling with temperature, called nucleus sampling,
+	 * where the model considers the results of the tokens with top_p probability mass.
+	 * So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+	 * We generally recommend altering this or temperature but not both.
+	 */
+	private @JsonProperty("top_p") Double topP;
+	/**
+	 * Whether to return log probabilities of the output tokens or not.
+	 * If true, returns the log probabilities of each output token returned in the content of message.
+	 */
+	private @JsonProperty("logprobs") Boolean logprobs;
+	/**
+	 * An integer between 0 and 20 specifying the number of most likely tokens to return at each token position,
+	 * each with an associated log probability. logprobs must be set to true if this parameter is used.
+	 */
+	private @JsonProperty("top_logprobs") Integer topLogprobs;
+
+
+	private @JsonProperty("tools") List tools;
+
+	/**
+	 * Controls which (if any) function is called by the model. none means the model will
+	 * not call a function and instead generates a message. auto means the model can pick
+	 * between generating a message or calling a function. Specifying a particular
+	 * function via {"type": "function", "function": {"name": "my_function"}} forces the
+	 * model to call that function. none is the default when no functions are present.
+	 * auto is the default if functions are present. Use the
+	 * {@link DeepSeekApi.ChatCompletionRequest.ToolChoiceBuilder} to create a tool choice
+	 * object.
+	 */
+	private @JsonProperty("tool_choice") Object toolChoice;
+
+	/**
+	 * DeepSeek Tool Function Callbacks to register with the ChatModel. For Prompt Options
+	 * the functionCallbacks are automatically enabled for the duration of the prompt
+	 * execution. For Default Options the functionCallbacks are registered but disabled by
+	 * default. Use the enableFunctions to set the functions from the registry to be used
+	 * by the ChatModel chat completion requests.
+	 */
+	@JsonIgnore
+	private List functionCallbacks = new ArrayList<>();
+
+	/**
+	 * List of functions, identified by their names, to configure for function calling in
+	 * the chat completion requests. Functions with those names must exist in the
+	 * functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
+	 * are automatically enabled for the duration of the prompt execution.
+	 *
+	 * Note that functions enabled with the default options are enabled for all chat
+	 * completion requests. This could impact the token count and the billing. If the
+	 * functions are set in the prompt options, then the enabled functions are only active
+	 * for the duration of this prompt execution.
+	 */
+	@JsonIgnore
+	private Set functions = new HashSet<>();
+
+	/**
+	 * If true, Spring AI will not handle the function calls internally, but will proxy them to the client.
+	 * It is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results.
+	 * If false, Spring AI will handle the function calls internally.
+ */ + @JsonIgnore + private Boolean proxyToolCalls; + + @JsonIgnore + private Map toolContext; + + public static Builder builder() { + return new Builder(); + } + + // @formatter:on + + @Override + public List getFunctionCallbacks() { + return this.functionCallbacks; + } + + @Override + public void setFunctionCallbacks(List functionCallbacks) { + this.functionCallbacks = functionCallbacks; + } + + @Override + public Set getFunctions() { + return this.functions; + } + + public void setFunctions(Set functionNames) { + this.functions = functionNames; + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public ResponseFormat getResponseFormat() { + return this.responseFormat; + } + + public void setResponseFormat(ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + @Override + @JsonIgnore + public List getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + + public List getStop() { + return this.stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + @Override + public Double getTemperature() { + return this.temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + @Override + public Double getTopP() { + return this.topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + public List getTools() { + return this.tools; + } + + public void setTools(List tools) { + this.tools = tools; + } + + public Object getToolChoice() { + return this.toolChoice; + } + + public void setToolChoice(Object toolChoice) { + this.toolChoice = toolChoice; + } + + public Boolean getLogprobs() { + return this.logprobs; + } + + public void setLogprobs(Boolean logprobs) { + this.logprobs = logprobs; + } + + public Integer getTopLogprobs() { + return this.topLogprobs; + } + + public void setTopLogprobs(Integer topLogprobs) { + this.topLogprobs = topLogprobs; + } + + @Override + @JsonIgnore + public Integer getTopK() { + return null; + } + + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + + @Override + public Map getToolContext() { + return this.toolContext; + } + + @Override + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + + @Override + public DeepSeekChatOptions copy() { + return builder().model(this.model) + .maxTokens(this.maxTokens) + .temperature(this.temperature) + .topP(this.topP) + .presencePenalty(this.presencePenalty) + .frequencyPenalty(this.frequencyPenalty) + .stop(this.stop) + .tools(this.tools) + .toolChoice(this.toolChoice) + .functionCallbacks(this.functionCallbacks) + .functions(this.functions) + .proxyToolCalls(this.proxyToolCalls) + .toolContext(this.toolContext) + .build(); + } + + 
@Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); + result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); + result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); + result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); + result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); + result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); + result = prime * result + ((this.proxyToolCalls == null) ? 0 : this.proxyToolCalls.hashCode()); + result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + DeepSeekChatOptions other = (DeepSeekChatOptions) obj; + if (this.model == null) { + if (other.model != null) { + return false; + } + } + else if (!this.model.equals(other.model)) { + return false; + } + if (this.frequencyPenalty == null) { + if (other.frequencyPenalty != null) { + return false; + } + } + else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) { + return false; + } + if (this.maxTokens == null) { + if (other.maxTokens != null) { + return false; + } + } + else if (!this.maxTokens.equals(other.maxTokens)) { + return false; + } + if (this.presencePenalty == null) { + if (other.presencePenalty != null) { + return false; + } + } + else if (!this.presencePenalty.equals(other.presencePenalty)) { + return false; + } + if (this.stop == null) { + if (other.stop != null) { + return false; + } + } + else if (!this.stop.equals(other.stop)) { + return false; + } + if (this.temperature == null) { + if (other.temperature != null) { + return false; + } + } + else if (!this.temperature.equals(other.temperature)) { + return false; + } + if (this.topP == null) { + if (other.topP != null) { + return false; + } + } + else if (!this.topP.equals(other.topP)) { + return false; + } + if (this.proxyToolCalls == null) { + return other.proxyToolCalls == null; + } + else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) { + return false; + } + if (this.toolContext == null) { + return other.toolContext == null; + } + else if (!this.toolContext.equals(other.toolContext)) { + return false; + } + return true; + } + + public static class Builder { + + private final DeepSeekChatOptions options = new DeepSeekChatOptions(); + + public Builder model(String model) { + this.options.model = model; + return this; + } + + public Builder model(DeepSeekApi.ChatModel model) { + this.options.model = model.getName(); + return this; + } + + public Builder maxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder temperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder topP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder logprobs(Boolean logprobs) { + this.options.logprobs = logprobs; + return this; + } + + public Builder topLogprobs(Integer topLogprobs) { + this.options.topLogprobs = topLogprobs; + return this; + } + + public Builder presencePenalty(Double presencePenalty) { + 
this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder responseFormat(ResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder frequencyPenalty(Double frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder stop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder tools(List tools) { + this.options.tools = tools; + return this; + } + + public Builder toolChoice(Object toolChoice) { + this.options.toolChoice = toolChoice; + return this; + } + + public Builder functionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder functions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder function(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + if (this.options.functions == null) { + this.options.functions = new HashSet<>(); + } + this.options.functions.add(functionName); + return this; + } + + public Builder proxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + + public Builder toolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + + public DeepSeekChatOptions build() { + return this.options; + } + + } + + public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) { + return builder().model(fromOptions.getModel()) + .frequencyPenalty(fromOptions.getFrequencyPenalty()) + .logprobs(fromOptions.getLogprobs()) + .topLogprobs(fromOptions.getTopLogprobs()) + .maxTokens(fromOptions.getMaxTokens()) + .presencePenalty(fromOptions.getPresencePenalty()) + .responseFormat(fromOptions.getResponseFormat()) + .stop(fromOptions.getStop()) + .temperature(fromOptions.getTemperature()) + .topP(fromOptions.getTopP()) + .tools(fromOptions.getTools()) + .toolChoice(fromOptions.getToolChoice()) + .functionCallbacks(fromOptions.getFunctionCallbacks()) + .functions(fromOptions.getFunctions()) + .proxyToolCalls(fromOptions.getProxyToolCalls()) + .toolContext(fromOptions.getToolContext()) + .build(); + } + +} diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/aot/DeepSeekRuntimeHints.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/aot/DeepSeekRuntimeHints.java new file mode 100644 index 00000000000..22d9ce8b56e --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/aot/DeepSeekRuntimeHints.java @@ -0,0 +1,42 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.deepseek.aot; + +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; + +import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; + +/** + * The DeepSeekRuntimeHints class is responsible for registering runtime hints for + * DeepSeek API classes. + * + * @author Geng Rong + */ +public class DeepSeekRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { + var mcs = MemberCategory.values(); + for (var tr : findJsonAnnotatedClassesInPackage(DeepSeekApi.class)) + hints.reflection().registerType(tr, mcs); + } + +} diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java new file mode 100644 index 00000000000..d3c45448289 --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java @@ -0,0 +1,1101 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.deepseek.api; + +import com.fasterxml.jackson.annotation.*; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import org.springframework.ai.model.ChatModelDescription; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; +import java.util.function.Predicate; + +import static org.springframework.ai.deepseek.api.common.DeepSeekConstants.*; + +/** + * Single class implementation of the DeepSeek Chat Completion API: + * https://platform.deepseek.com/api-docs/api/create-chat-completion + * + * @author Geng Rong + */ +public class DeepSeekApi { + + public static final DeepSeekApi.ChatModel DEFAULT_CHAT_MODEL = ChatModel.DEEPSEEK_REASONER; + + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; + + private final String completionsPath; + + private final String betaFeaturePath; + + private final RestClient restClient; + + private final WebClient webClient; + + private DeepSeekStreamFunctionCallingHelper chunkMerger = new DeepSeekStreamFunctionCallingHelper(); + + /** + * Create a new chat completion api with base URL set to https://api.deepseek.com + * @param apiKey DeepSeek apiKey. + */ + public DeepSeekApi(String apiKey) { + this(DEFAULT_BASE_URL, apiKey); + } + + /** + * Create a new chat completion api. + * @param baseUrl api base URL. + * @param apiKey DeepSeek apiKey. + */ + public DeepSeekApi(String baseUrl, String apiKey) { + this(baseUrl, apiKey, RestClient.builder(), WebClient.builder()); + } + + /** + * Create a new chat completion api. + * @param baseUrl api base URL. + * @param apiKey DeepSeek apiKey. + * @param restClientBuilder RestClient builder. + * @param webClientBuilder WebClient builder. + */ + public DeepSeekApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder, + WebClient.Builder webClientBuilder) { + this(baseUrl, apiKey, restClientBuilder, webClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); + } + + /** + * Create a new chat completion api. + * @param baseUrl api base URL. + * @param apiKey DeepSeek apiKey. + * @param restClientBuilder RestClient builder. + * @param webClientBuilder WebClient builder. + * @param responseErrorHandler Response error handler. + */ + public DeepSeekApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder, + WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { + this(baseUrl, apiKey, DEFAULT_COMPLETIONS_PATH, DEFAULT_BETA_PATH, restClientBuilder, webClientBuilder, + responseErrorHandler); + } + + /** + * Create a new chat completion api. + * @param baseUrl api base URL. + * @param apiKey DeepSeek apiKey. + * @param completionsPath the path to the chat completions endpoint. + * @param betaFeaturePath the path to the beta feature endpoint. 
+ * @param restClientBuilder RestClient builder. + * @param webClientBuilder WebClient builder. + * @param responseErrorHandler Response error handler. + */ + public DeepSeekApi(String baseUrl, String apiKey, String completionsPath, String betaFeaturePath, + RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, + ResponseErrorHandler responseErrorHandler) { + this(baseUrl, apiKey, CollectionUtils.toMultiValueMap(Map.of()), completionsPath, betaFeaturePath, + restClientBuilder, webClientBuilder, responseErrorHandler); + } + + /** + * Create a new chat completion api. + * @param baseUrl api base URL. + * @param apiKey DeepSeek apiKey. + * @param headers the http headers to use. + * @param completionsPath the path to the chat completions endpoint. + * @param betaFeaturePath the path to the beta feature endpoint. + * @param restClientBuilder RestClient builder. + * @param webClientBuilder WebClient builder. + * @param responseErrorHandler Response error handler. + */ + public DeepSeekApi(String baseUrl, String apiKey, MultiValueMap headers, String completionsPath, + String betaFeaturePath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, + ResponseErrorHandler responseErrorHandler) { + + Assert.hasText(completionsPath, "Completions Path must not be null"); + Assert.hasText(betaFeaturePath, "Beta feature path must not be null"); + Assert.notNull(headers, "Headers must not be null"); + + this.completionsPath = completionsPath; + this.betaFeaturePath = betaFeaturePath; + // @formatter:off + Consumer finalHeaders = h -> { + h.setBearerAuth(apiKey); + h.setContentType(MediaType.APPLICATION_JSON); + h.addAll(headers); + }; + this.restClient = restClientBuilder.baseUrl(baseUrl) + .defaultHeaders(finalHeaders) + .defaultStatusHandler(responseErrorHandler) + .build(); + + this.webClient = webClientBuilder + .baseUrl(baseUrl) + .defaultHeaders(finalHeaders) + .build(); // @formatter:on + } + + /** + * Creates a model response for the given chat conversation. + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code + * and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + + return this.restClient.post() + .uri(this.completionsPath) + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + return chatCompletionStream(chatRequest, new LinkedMultiValueMap<>()); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @param additionalHttpHeader Optional, additional HTTP headers to be added to the + * request. + * @return Returns a {@link Flux} stream from chat completion chunks. 
+ */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest, + MultiValueMap additionalHttpHeader) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + AtomicBoolean isInsideTool = new AtomicBoolean(false); + + return this.webClient.post() + .uri(this.completionsPath) + .headers(headers -> headers.addAll(additionalHttpHeader)) + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + // cancels the flux stream after the "[DONE]" is received. + .takeUntil(SSE_DONE_PREDICATE) + // filters out the "[DONE]" message. + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) + // Detect is the chunk is part of a streaming function call. + .map(chunk -> { + if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { + isInsideTool.set(true); + } + return chunk; + }) + // Group all chunks belonging to the same function call. + // Flux -> Flux> + .windowUntil(chunk -> { + if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + // Merging the window chunks into a single chunk. + // Reduce the inner Flux window into a single + // Mono, + // Flux> -> Flux> + .concatMapIterable(window -> { + Mono monoChunk = window.reduce( + new ChatCompletionChunk(null, null, null, null, null, null, null, null), + (previous, current) -> this.chunkMerger.merge(previous, current)); + return List.of(monoChunk); + }) + // Flux> -> Flux + .flatMap(mono -> mono); + } + + /** + * DeepSeek Chat Completion + * Models + */ + public enum ChatModel implements ChatModelDescription { + + /** + * The backend model of deepseek-chat has been updated to DeepSeek-V3, you can + * access DeepSeek-V3 without modification to the model name. The open-source + * DeepSeek-V3 model supports 128K context window, and DeepSeek-V3 on API/Web + * supports 64K context window. Context window: 64k tokens + */ + DEEPSEEK_CHAT("deepseek-chat"), + + /** + * deepseek-reasoner is a reasoning model developed by DeepSeek. Before delivering + * the final answer, the model first generates a Chain of Thought (CoT) to enhance + * the accuracy of its responses. Our API provides users with access to the CoT + * content generated by deepseek-reasoner, enabling them to view, display, and + * distill it. + */ + DEEPSEEK_REASONER("deepseek-reasoner"); + + public final String value; + + ChatModel(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + @Override + public String getName() { + return value; + } + + } + + /** + * The reason the model stopped generating tokens. + */ + public enum ChatCompletionFinishReason { + + /** + * The model hit a natural stop point or a provided stop sequence. + */ + @JsonProperty("stop") + STOP, + /** + * The maximum number of tokens specified in the request was reached. + */ + @JsonProperty("length") + LENGTH, + /** + * The content was omitted due to a flag from our content filters. + */ + @JsonProperty("content_filter") + CONTENT_FILTER, + /** + * The model called a tool. + */ + @JsonProperty("tool_calls") + TOOL_CALLS, + /** + * Only for compatibility with Mistral AI API. + */ + @JsonProperty("tool_call") + TOOL_CALL + + } + + /** + * Represents a tool the model may call. Currently, only functions are supported as a + * tool. 
+ */ + @JsonInclude(Include.NON_NULL) + public static class FunctionTool { + + /** + * The type of the tool. Currently, only 'function' is supported. + */ + @JsonProperty("type") + private Type type = Type.FUNCTION; + + /** + * The function definition. + */ + @JsonProperty("function") + private Function function; + + public FunctionTool() { + + } + + /** + * Create a tool of type 'function' and the given function definition. + * @param type the tool type + * @param function function definition + */ + public FunctionTool(Type type, Function function) { + this.type = type; + this.function = function; + } + + /** + * Create a tool of type 'function' and the given function definition. + * @param function function definition. + */ + public FunctionTool(Function function) { + this(Type.FUNCTION, function); + } + + public Type getType() { + return this.type; + } + + public Function getFunction() { + return this.function; + } + + public void setType(Type type) { + this.type = type; + } + + public void setFunction(Function function) { + this.function = function; + } + + /** + * Create a tool of type 'function' and the given function definition. + */ + public enum Type { + + /** + * Function tool type. + */ + @JsonProperty("function") + FUNCTION + + } + + /** + * Function definition. + */ + @JsonInclude(Include.NON_NULL) + public static class Function { + + @JsonProperty("description") + private String description; + + @JsonProperty("name") + private String name; + + @JsonProperty("parameters") + private Map parameters; + + @JsonProperty("strict") + Boolean strict; + + @JsonIgnore + private String jsonSchema; + + /** + * NOTE: Required by Jackson, JSON deserialization! + */ + @SuppressWarnings("unused") + private Function() { + } + + /** + * Create tool function definition. + * @param description A description of what the function does, used by the + * model to choose when and how to call the function. + * @param name The name of the function to be called. Must be a-z, A-Z, 0-9, + * or contain underscores and dashes, with a maximum length of 64. + * @param parameters The parameters the functions accepts, described as a JSON + * Schema object. To describe a function that accepts no parameters, provide + * the value {"type": "object", "properties": {}}. + * @param strict Whether to enable strict schema adherence when generating the + * function call. If set to true, the model will follow the exact schema + * defined in the parameters field. Only a subset of JSON Schema is supported + * when strict is true. + */ + public Function(String description, String name, Map parameters, Boolean strict) { + this.description = description; + this.name = name; + this.parameters = parameters; + this.strict = strict; + } + + /** + * Create tool function definition. + * @param description tool function description. + * @param name tool function name. + * @param jsonSchema tool function schema as json. 
+			 */
+			public Function(String description, String name, String jsonSchema) {
+				this(description, name, ModelOptionsUtils.jsonToMap(jsonSchema), null);
+			}
+
+			public String getDescription() {
+				return this.description;
+			}
+
+			public String getName() {
+				return this.name;
+			}
+
+			public Map<String, Object> getParameters() {
+				return this.parameters;
+			}
+
+			public void setDescription(String description) {
+				this.description = description;
+			}
+
+			public void setName(String name) {
+				this.name = name;
+			}
+
+			public void setParameters(Map<String, Object> parameters) {
+				this.parameters = parameters;
+			}
+
+			public Boolean getStrict() {
+				return this.strict;
+			}
+
+			public void setStrict(Boolean strict) {
+				this.strict = strict;
+			}
+
+			public String getJsonSchema() {
+				return this.jsonSchema;
+			}
+
+			public void setJsonSchema(String jsonSchema) {
+				this.jsonSchema = jsonSchema;
+				if (jsonSchema != null) {
+					this.parameters = ModelOptionsUtils.jsonToMap(jsonSchema);
+				}
+			}
+
+		}
+
+	}
+
+	/**
+	 * Creates a model response for the given chat conversation.
+	 *
+	 * @param messages A list of messages comprising the conversation so far.
+	 * @param model ID of the model to use.
+	 * @param frequencyPenalty Number between -2.0 and 2.0. Positive values penalize new
+	 * tokens based on their existing frequency in the text so far, decreasing the model's
+	 * likelihood to repeat the same line verbatim.
+	 * @param maxTokens The maximum number of tokens that can be generated in the chat
+	 * completion. This value can be used to control costs for text generated via the API.
+	 * @param presencePenalty Number between -2.0 and 2.0. Positive values penalize new
+	 * tokens based on whether they appear in the text so far, increasing the model's
+	 * likelihood to talk about new topics.
+	 * @param responseFormat An object specifying the format that the model must output.
+	 * Setting to { "type": "json_object" } enables JSON mode, which guarantees the
+	 * message the model generates is valid JSON.
+	 * @param stop A string or a list containing up to 4 strings; upon encountering these
+	 * words, the API will cease generating more tokens.
+	 * @param stream If set, partial message deltas will be sent. Tokens will be sent as
+	 * data-only server-sent events as they become available, with the stream terminated
+	 * by a data: [DONE] message.
+	 * @param temperature What sampling temperature to use, between 0 and 2. Higher values
+	 * like 0.8 will make the output more random, while lower values like 0.2 will make it
+	 * more focused and deterministic. We generally recommend altering this or top_p but
+	 * not both.
+	 * @param topP An alternative to sampling with temperature, called nucleus sampling,
+	 * where the model considers the results of the tokens with top_p probability mass. So
+	 * 0.1 means only the tokens comprising the top 10% probability mass are considered.
+	 * We generally recommend altering this or temperature but not both.
+	 * @param logprobs Whether to return log probabilities of the output tokens or not. If
+	 * true, returns the log probabilities of each output token returned in the content of
+	 * message.
+	 * @param topLogprobs An integer between 0 and 20 specifying the number of most likely
+	 * tokens to return at each token position, each with an associated log probability.
+	 * logprobs must be set to true if this parameter is used.
+	 * @param tools A list of tools the model may call. Currently, only functions are
+	 * supported as a tool. Use this to provide a list of functions the model may generate
+	 * JSON inputs for.
+	 * @param toolChoice Controls which (if any) function is called by the model. none
+	 * means the model will not call a function and instead generates a message. auto
+	 * means the model can pick between generating a message or calling a function.
+	 * Specifying a particular function via {"type": "function", "function": {"name":
+	 * "my_function"}} forces the model to call that function. none is the default when no
+	 * functions are present. auto is the default if functions are present. Use the
+	 * {@link ToolChoiceBuilder} to create the tool choice value.
+	 */
+	@JsonInclude(Include.NON_NULL)
+	public record ChatCompletionRequest(// @formatter:off
+			@JsonProperty("messages") List<ChatCompletionMessage> messages,
+			@JsonProperty("model") String model,
+			@JsonProperty("frequency_penalty") Double frequencyPenalty,
+			@JsonProperty("max_tokens") Integer maxTokens,
+			@JsonProperty("presence_penalty") Double presencePenalty,
+			@JsonProperty("response_format") ResponseFormat responseFormat,
+			@JsonProperty("stop") List<String> stop,
+			@JsonProperty("stream") Boolean stream,
+			@JsonProperty("temperature") Double temperature,
+			@JsonProperty("top_p") Double topP,
+			@JsonProperty("logprobs") Boolean logprobs,
+			@JsonProperty("top_logprobs") Integer topLogprobs,
+			@JsonProperty("tools") List<FunctionTool> tools,
+			@JsonProperty("tool_choice") Object toolChoice) {
+
+		/**
+		 * Shortcut constructor for a chat completion request with the given messages for streaming.
+		 *
+		 * @param messages A list of messages comprising the conversation so far.
+		 * @param stream If set, partial message deltas will be sent. Tokens will be sent as data-only server-sent events
+		 * as they become available, with the stream terminated by a data: [DONE] message.
+		 */
+		public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
+			this(messages, null, null, null, null, null,
+					null, stream, null, null, null, null, null, null);
+		}
+
+		/**
+		 * Shortcut constructor for a chat completion request with the given messages, model and temperature.
+		 *
+		 * @param messages A list of messages comprising the conversation so far.
+		 * @param model ID of the model to use.
+		 * @param temperature What sampling temperature to use, between 0 and 2.
+		 */
+		public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
+			this(messages, model, null,
+					null, null, null, null, false, temperature, null,
+					null, null, null, null);
+		}
+
+		/**
+		 * Shortcut constructor for a chat completion request with the given messages, model, temperature and control for streaming.
+		 *
+		 * @param messages A list of messages comprising the conversation so far.
+		 * @param model ID of the model to use.
+		 * @param temperature What sampling temperature to use, between 0 and 2.
+		 * @param stream If set, partial message deltas will be sent. Tokens will be sent as data-only server-sent events
+		 * as they become available, with the stream terminated by a data: [DONE] message.
+		 */
+		public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature, boolean stream) {
+			this(messages, model, null,
+					null, null, null, null, stream, temperature, null,
+					null, null, null, null);
+		}
+
+		/**
+		 * Shortcut constructor for a chat completion request with the given messages, model, tools and tool choice.
+		 * Streaming is set to false, temperature to 0.8 and all other parameters are null.
+ * + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. + * @param tools A list of tools the model may call. Currently, only functions are supported as a tool. + * @param toolChoice Controls which (if any) function is called by the model. + */ + public ChatCompletionRequest(List messages, String model, + List tools, Object toolChoice) { + this(messages, model, null, + null, null, null, null, false, 0.8, null, + null, null, tools, toolChoice ); + } + + /** + * Helper factory that creates a tool_choice of type 'none', 'auto' or selected function by name. + */ + public static class ToolChoiceBuilder { + /** + * Model can pick between generating a message or calling a function. + */ + public static final String AUTO = "auto"; + /** + * Model will not call a function and instead generates a message + */ + public static final String NONE = "none"; + + /** + * Specifying a particular function forces the model to call that function. + */ + public static Object FUNCTION(String functionName) { + return Map.of("type", "function", "function", Map.of("name", functionName)); + } + } + + /** + * Parameters for audio output. Required when audio output is requested with outputModalities: ["audio"]. + * @param voice Specifies the voice type. + * @param format Specifies the output audio format. + */ + @JsonInclude(Include.NON_NULL) + public record AudioParameters( + @JsonProperty("voice") Voice voice, + @JsonProperty("format") AudioResponseFormat format) { + + /** + * Specifies the voice type. + */ + public enum Voice { + /** Alloy voice */ + @JsonProperty("alloy") ALLOY, + /** Echo voice */ + @JsonProperty("echo") ECHO, + /** Fable voice */ + @JsonProperty("fable") FABLE, + /** Onyx voice */ + @JsonProperty("onyx") ONYX, + /** Nova voice */ + @JsonProperty("nova") NOVA, + /** Shimmer voice */ + @JsonProperty("shimmer") SHIMMER + } + + /** + * Specifies the output audio format. + */ + public enum AudioResponseFormat { + /** MP3 format */ + @JsonProperty("mp3") MP3, + /** FLAC format */ + @JsonProperty("flac") FLAC, + /** OPUS format */ + @JsonProperty("opus") OPUS, + /** PCM16 format */ + @JsonProperty("pcm16") PCM16, + /** WAV format */ + @JsonProperty("wav") WAV + } + } + + /** + * @param includeUsage If set, an additional chunk will be streamed + * before the data: [DONE] message. The usage field on this chunk + * shows the token usage statistics for the entire request, and + * the choices field will always be an empty array. All other chunks + * will also include a usage field, but with a null value. + */ + @JsonInclude(Include.NON_NULL) + public record StreamOptions( + @JsonProperty("include_usage") Boolean includeUsage) { + + public static StreamOptions INCLUDE_USAGE = new StreamOptions(true); + } + } // @formatter:on + + /** + * Message comprising the conversation. + * + * @param rawContent The contents of the message. Can be either a {@link MediaContent} + * or a {@link String}. The response message content is always a {@link String}. + * @param role The role of the messages author. Could be one of the {@link Role} + * types. + * @param name An optional name for the participant. Provides the model information to + * differentiate between participants of the same role. In case of Function calling, + * the name is the function name that the message is responding to. + * @param toolCallId Tool call that this message is responding to. Only applicable for + * the {@link Role#TOOL} role and null otherwise. 
+ * @param toolCalls The tool calls generated by the model, such as function calls. + * Applicable only for {@link Role#ASSISTANT} role and null otherwise. + */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletionMessage(// @formatter:off + @JsonProperty("content") Object rawContent, + @JsonProperty("role") Role role, + @JsonProperty("name") String name, + @JsonProperty("tool_call_id") String toolCallId, + @JsonProperty("tool_calls") + @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List toolCalls + ) { // @formatter:on + + /** + * Create a chat completion message with the given content and role. All other + * fields are null. + * @param content The contents of the message. + * @param role The role of the author of this message. + */ + public ChatCompletionMessage(Object content, Role role) { + this(content, role, null, null, null); + + } + + /** + * Get message content as String. + */ + public String content() { + if (this.rawContent == null) { + return null; + } + if (this.rawContent instanceof String text) { + return text; + } + throw new IllegalStateException("The content is not a string!"); + } + + /** + * The role of the author of this message. + */ + public enum Role { + + /** + * System message. + */ + @JsonProperty("system") + SYSTEM, + /** + * User message. + */ + @JsonProperty("user") + USER, + /** + * Assistant message. + */ + @JsonProperty("assistant") + ASSISTANT, + /** + * Tool message. + */ + @JsonProperty("tool") + TOOL + + } + + /** + * An array of content parts with a defined type. Each MediaContent can be of + * either "text", "image_url", or "input_audio" type. Only one option allowed. + * + * @param type Content type, each can be of type text or image_url. + * @param text The text content of the message. + * @param imageUrl The image content of the message. You can pass multiple images + * by adding multiple image_url content parts. Image input is only supported when + * using the gpt-4-visual-preview model. + * @param inputAudio Audio content part. + */ + @JsonInclude(Include.NON_NULL) + public record MediaContent(// @formatter:off + @JsonProperty("type") String type, + @JsonProperty("text") String text, + @JsonProperty("image_url") ImageUrl imageUrl, + @JsonProperty("input_audio") InputAudio inputAudio) { // @formatter:on + + /** + * Shortcut constructor for a text content. + * @param text The text content of the message. + */ + public MediaContent(String text) { + this("text", text, null, null); + } + + /** + * Shortcut constructor for an image content. + * @param imageUrl The image content of the message. + */ + public MediaContent(ImageUrl imageUrl) { + this("image_url", null, imageUrl, null); + } + + /** + * Shortcut constructor for an audio content. + * @param inputAudio The audio content of the message. + */ + public MediaContent(InputAudio inputAudio) { + this("input_audio", null, null, inputAudio); + } + + /** + * @param data Base64 encoded audio data. + * @param format The format of the encoded audio data. Currently supports + * "wav" and "mp3". + */ + @JsonInclude(Include.NON_NULL) + public record InputAudio(// @formatter:off + @JsonProperty("data") String data, + @JsonProperty("format") Format format) { + + public enum Format { + /** MP3 audio format */ + @JsonProperty("mp3") MP3, + /** WAV audio format */ + @JsonProperty("wav") WAV + } // @formatter:on + } + + /** + * Shortcut constructor for an image content. + * + * @param url Either a URL of the image or the base64 encoded image data. 
The + * base64 encoded image data must have a special prefix in the following + * format: "data:{mimetype};base64,{base64-encoded-image-data}". + * @param detail Specifies the detail level of the image. + */ + @JsonInclude(Include.NON_NULL) + public record ImageUrl(@JsonProperty("url") String url, @JsonProperty("detail") String detail) { + + public ImageUrl(String url) { + this(url, null); + } + + } + + } + + /** + * The relevant tool call. + * + * @param index The index of the tool call in the list of tool calls. Required in + * case of streaming. + * @param id The ID of the tool call. This ID must be referenced when you submit + * the tool outputs in using the Submit tool outputs to run endpoint. + * @param type The type of tool call the output is required for. For now, this is + * always function. + * @param function The function definition. + */ + @JsonInclude(Include.NON_NULL) + public record ToolCall(// @formatter:off + @JsonProperty("index") Integer index, + @JsonProperty("id") String id, + @JsonProperty("type") String type, + @JsonProperty("function") ChatCompletionFunction function) { // @formatter:on + + public ToolCall(String id, String type, ChatCompletionFunction function) { + this(null, id, type, function); + } + + } + + /** + * The function definition. + * + * @param name The name of the function. + * @param arguments The arguments that the model expects you to pass to the + * function. + */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletionFunction(// @formatter:off + @JsonProperty("name") String name, + @JsonProperty("arguments") String arguments) { // @formatter:on + } + } + + /** + * Represents a chat completion response returned by model, based on the provided + * input. + * + * @param id A unique identifier for the chat completion. + * @param choices A list of chat completion choices. Can be more than one if n is + * greater than 1. + * @param created The Unix timestamp (in seconds) of when the chat completion was + * created. + * @param model The model used for the chat completion. + * @param systemFingerprint This fingerprint represents the backend configuration that + * the model runs with. Can be used in conjunction with the seed request parameter to + * understand when backend changes have been made that might impact determinism. + * @param object The object type, which is always chat.completion. + * @param usage Usage statistics for the completion request. + */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletion(// @formatter:off + @JsonProperty("id") String id, + @JsonProperty("choices") List choices, + @JsonProperty("created") Long created, + @JsonProperty("model") String model, + @JsonProperty("system_fingerprint") String systemFingerprint, + @JsonProperty("object") String object, + @JsonProperty("usage") Usage usage + ) { // @formatter:on + + /** + * Chat completion choice. + * + * @param finishReason The reason the model stopped generating tokens. + * @param index The index of the choice in the list of choices. + * @param message A chat completion message generated by the model. + * @param logprobs Log probability information for the choice. + */ + @JsonInclude(Include.NON_NULL) + public record Choice(// @formatter:off + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("index") Integer index, + @JsonProperty("message") ChatCompletionMessage message, + @JsonProperty("logprobs") LogProbs logprobs) { // @formatter:on + } + + } + + /** + * Log probability information for the choice. 
+ * + * @param content A list of message content tokens with log probability information. + * @param refusal A list of message refusal tokens with log probability information. + */ + @JsonInclude(Include.NON_NULL) + public record LogProbs(@JsonProperty("content") List content, + @JsonProperty("refusal") List refusal) { + + /** + * Message content tokens with log probability information. + * + * @param token The token. + * @param logprob The log probability of the token. + * @param probBytes A list of integers representing the UTF-8 bytes representation + * of the token. Useful in instances where characters are represented by multiple + * tokens and their byte representations must be combined to generate the correct + * text representation. Can be null if there is no bytes representation for the + * token. + * @param topLogprobs List of the most likely tokens and their log probability, at + * this token position. In rare cases, there may be fewer than the number of + * requested top_logprobs returned. + */ + @JsonInclude(Include.NON_NULL) + public record Content(// @formatter:off + @JsonProperty("token") String token, + @JsonProperty("logprob") Float logprob, + @JsonProperty("bytes") List probBytes, + @JsonProperty("top_logprobs") List topLogprobs) { // @formatter:on + + /** + * The most likely tokens and their log probability, at this token position. + * + * @param token The token. + * @param logprob The log probability of the token. + * @param probBytes A list of integers representing the UTF-8 bytes + * representation of the token. Useful in instances where characters are + * represented by multiple tokens and their byte representations must be + * combined to generate the correct text representation. Can be null if there + * is no bytes representation for the token. + */ + @JsonInclude(Include.NON_NULL) + public record TopLogProbs(// @formatter:off + @JsonProperty("token") String token, + @JsonProperty("logprob") Float logprob, + @JsonProperty("bytes") List probBytes) { // @formatter:on + } + + } + + } + + // Embeddings API + + /** + * Usage statistics for the completion request. + * + * @param completionTokens Number of tokens in the generated completion. Only + * applicable for completion requests. + * @param promptTokens Number of tokens in the prompt. + * @param totalTokens Total number of tokens used in the request (prompt + + * completion). + * @param promptTokensDetails Breakdown of tokens used in the prompt. + * @param completionTokenDetails Breakdown of tokens used in a completion. + */ + @JsonInclude(Include.NON_NULL) + public record Usage(// @formatter:off + @JsonProperty("completion_tokens") Integer completionTokens, + @JsonProperty("prompt_tokens") Integer promptTokens, + @JsonProperty("total_tokens") Integer totalTokens, + @JsonProperty("prompt_tokens_details") PromptTokensDetails promptTokensDetails, + @JsonProperty("completion_tokens_details") CompletionTokenDetails completionTokenDetails) { // @formatter:on + + public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) { + this(completionTokens, promptTokens, totalTokens, null, null); + } + + /** + * Breakdown of tokens used in the prompt + * + * @param audioTokens Audio input tokens present in the prompt. + * @param cachedTokens Cached tokens present in the prompt. 
+ */ + @JsonInclude(Include.NON_NULL) + public record PromptTokensDetails(// @formatter:off + @JsonProperty("audio_tokens") Integer audioTokens, + @JsonProperty("cached_tokens") Integer cachedTokens) { // @formatter:on + } + + /** + * Breakdown of tokens used in a completion. + * + * @param reasoningTokens Number of tokens generated by the model for reasoning. + * @param acceptedPredictionTokens Number of tokens generated by the model for + * accepted predictions. + * @param audioTokens Number of tokens generated by the model for audio. + * @param rejectedPredictionTokens Number of tokens generated by the model for + * rejected predictions. + */ + @JsonInclude(Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CompletionTokenDetails(// @formatter:off + @JsonProperty("reasoning_tokens") Integer reasoningTokens, + @JsonProperty("accepted_prediction_tokens") Integer acceptedPredictionTokens, + @JsonProperty("audio_tokens") Integer audioTokens, + @JsonProperty("rejected_prediction_tokens") Integer rejectedPredictionTokens) { // @formatter:on + } + } + + /** + * Represents a streamed chunk of a chat completion response returned by model, based + * on the provided input. + * + * @param id A unique identifier for the chat completion. Each chunk has the same ID. + * @param choices A list of chat completion choices. Can be more than one if n is + * greater than 1. + * @param created The Unix timestamp (in seconds) of when the chat completion was + * created. Each chunk has the same timestamp. + * @param model The model used for the chat completion. + * @param serviceTier The service tier used for processing the request. This field is + * only included if the service_tier parameter is specified in the request. + * @param systemFingerprint This fingerprint represents the backend configuration that + * the model runs with. Can be used in conjunction with the seed request parameter to + * understand when backend changes have been made that might impact determinism. + * @param object The object type, which is always 'chat.completion.chunk'. + * @param usage Usage statistics for the completion request. Present in the last chunk + * only if the StreamOptions.includeUsage is set to true. + */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletionChunk(// @formatter:off + @JsonProperty("id") String id, + @JsonProperty("choices") List choices, + @JsonProperty("created") Long created, + @JsonProperty("model") String model, + @JsonProperty("service_tier") String serviceTier, + @JsonProperty("system_fingerprint") String systemFingerprint, + @JsonProperty("object") String object, + @JsonProperty("usage") Usage usage) { // @formatter:on + + /** + * Chat completion choice. + * + * @param finishReason The reason the model stopped generating tokens. + * @param index The index of the choice in the list of choices. + * @param delta A chat completion delta generated by streamed model responses. + * @param logprobs Log probability information for the choice. 
+ */ + @JsonInclude(Include.NON_NULL) + public record ChunkChoice(// @formatter:off + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("index") Integer index, + @JsonProperty("delta") ChatCompletionMessage delta, + @JsonProperty("logprobs") LogProbs logprobs) { // @formatter:on + + } + + } + +} diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekStreamFunctionCallingHelper.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekStreamFunctionCallingHelper.java new file mode 100644 index 00000000000..7fe854428f4 --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekStreamFunctionCallingHelper.java @@ -0,0 +1,176 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.deepseek.api; + +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionChunk; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionChunk.ChunkChoice; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionFinishReason; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.ChatCompletionFunction; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Role; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.ToolCall; +import org.springframework.util.CollectionUtils; + +import java.util.ArrayList; +import java.util.List; + +/** + * Helper class to support Streaming function calling. It can merge the streamed + * ChatCompletionChunk in case of function calling message. + * + * @author Geng Rong + */ +public class DeepSeekStreamFunctionCallingHelper { + + public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChunk current) { + + if (previous == null) { + return current; + } + + String id = (current.id() != null ? current.id() : previous.id()); + Long created = (current.created() != null ? current.created() : previous.created()); + String model = (current.model() != null ? current.model() : previous.model()); + String serviceTier = (current.serviceTier() != null ? current.serviceTier() : previous.serviceTier()); + String systemFingerprint = (current.systemFingerprint() != null ? current.systemFingerprint() + : previous.systemFingerprint()); + String object = (current.object() != null ? current.object() : previous.object()); + DeepSeekApi.Usage usage = (current.usage() != null ? current.usage() : previous.usage()); + + ChunkChoice previousChoice0 = (CollectionUtils.isEmpty(previous.choices()) ? null : previous.choices().get(0)); + ChunkChoice currentChoice0 = (CollectionUtils.isEmpty(current.choices()) ? null : current.choices().get(0)); + + ChunkChoice choice = merge(previousChoice0, currentChoice0); + List chunkChoices = choice == null ? 
List.of() : List.of(choice);
+		return new ChatCompletionChunk(id, chunkChoices, created, model, serviceTier, systemFingerprint, object, usage);
+	}
+
+	private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
+		if (previous == null) {
+			return current;
+		}
+
+		ChatCompletionFinishReason finishReason = (current.finishReason() != null ? current.finishReason()
+				: previous.finishReason());
+		Integer index = (current.index() != null ? current.index() : previous.index());
+
+		ChatCompletionMessage message = merge(previous.delta(), current.delta());
+
+		DeepSeekApi.LogProbs logprobs = (current.logprobs() != null ? current.logprobs() : previous.logprobs());
+		return new ChunkChoice(finishReason, index, message, logprobs);
+	}
+
+	private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {
+		String content = (current.content() != null ? current.content()
+				: "" + ((previous.content() != null) ? previous.content() : ""));
+		Role role = (current.role() != null ? current.role() : previous.role());
+		role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null)
+		String name = (current.name() != null ? current.name() : previous.name());
+		String toolCallId = (current.toolCallId() != null ? current.toolCallId() : previous.toolCallId());
+
+		List<ToolCall> toolCalls = new ArrayList<>();
+		ToolCall lastPreviousToolCall = null;
+		if (previous.toolCalls() != null) {
+			lastPreviousToolCall = previous.toolCalls().get(previous.toolCalls().size() - 1);
+			if (previous.toolCalls().size() > 1) {
+				toolCalls.addAll(previous.toolCalls().subList(0, previous.toolCalls().size() - 1));
+			}
+		}
+		if (current.toolCalls() != null) {
+			if (current.toolCalls().size() > 1) {
+				throw new IllegalStateException("Currently only one tool call is supported per message!");
+			}
+			var currentToolCall = current.toolCalls().iterator().next();
+			if (currentToolCall.id() != null) {
+				if (lastPreviousToolCall != null) {
+					toolCalls.add(lastPreviousToolCall);
+				}
+				toolCalls.add(currentToolCall);
+			}
+			else {
+				toolCalls.add(merge(lastPreviousToolCall, currentToolCall));
+			}
+		}
+		else {
+			if (lastPreviousToolCall != null) {
+				toolCalls.add(lastPreviousToolCall);
+			}
+		}
+		return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls);
+	}
+
+	private ToolCall merge(ToolCall previous, ToolCall current) {
+		if (previous == null) {
+			return current;
+		}
+		String id = (current.id() != null ? current.id() : previous.id());
+		String type = (current.type() != null ? current.type() : previous.type());
+		ChatCompletionFunction function = merge(previous.function(), current.function());
+		return new ToolCall(id, type, function);
+	}
+
+	private ChatCompletionFunction merge(ChatCompletionFunction previous, ChatCompletionFunction current) {
+		if (previous == null) {
+			return current;
+		}
+		String name = (current.name() != null ? current.name() : previous.name());
+		StringBuilder arguments = new StringBuilder();
+		if (previous.arguments() != null) {
+			arguments.append(previous.arguments());
+		}
+		if (current.arguments() != null) {
+			arguments.append(current.arguments());
+		}
+		return new ChatCompletionFunction(name, arguments.toString());
+	}
+
+	/**
+	 * @param chatCompletion the ChatCompletionChunk to check
+	 * @return true if the ChatCompletionChunk is a streaming tool function call.
+ */ + public boolean isStreamingToolFunctionCall(ChatCompletionChunk chatCompletion) { + + if (chatCompletion == null || CollectionUtils.isEmpty(chatCompletion.choices())) { + return false; + } + + var choice = chatCompletion.choices().get(0); + if (choice == null || choice.delta() == null) { + return false; + } + return !CollectionUtils.isEmpty(choice.delta().toolCalls()); + } + + /** + * @param chatCompletion the ChatCompletionChunk to check + * @return true if the ChatCompletionChunk is a streaming tool function call and it is + * the last one. + */ + public boolean isStreamingToolFunctionCallFinish(ChatCompletionChunk chatCompletion) { + + if (chatCompletion == null || CollectionUtils.isEmpty(chatCompletion.choices())) { + return false; + } + + var choice = chatCompletion.choices().get(0); + if (choice == null || choice.delta() == null) { + return false; + } + return choice.finishReason() == ChatCompletionFinishReason.TOOL_CALLS; + } + +} diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/ResponseFormat.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/ResponseFormat.java new file mode 100644 index 00000000000..826675545fa --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/ResponseFormat.java @@ -0,0 +1,126 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.deepseek.api; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +/** + * An object specifying the format that the model must output. Setting to { "type": + * "json_object" } enables JSON Output, which guarantees the message the model generates + * is valid JSON. + *
+ * <p>
+ * Important: When using JSON Output, you must also instruct the model to produce JSON + * yourself via a system or user message. Without this, the model may generate an unending + * stream of whitespace until the generation reaches the token limit, resulting in a + * long-running and seemingly "stuck" request. Also note that the message content may be + * partially cut off if finish_reason="length", which indicates the generation exceeded + * max_tokens or the conversation exceeded the max context length. + *
+ * <p>
+ * References: + * DeepSeek API - + * Create Chat Completion + * + * @author Geng Rong + */ + +@JsonInclude(Include.NON_NULL) +public class ResponseFormat { + + /** + * Type Must be one of 'text', 'json_object'. + */ + @JsonProperty("type") + private Type type; + + public Type getType() { + return this.type; + } + + public void setType(Type type) { + this.type = type; + } + + private ResponseFormat(Type type) { + this.type = type; + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ResponseFormat that = (ResponseFormat) o; + return this.type == that.type; + } + + @Override + public int hashCode() { + return Objects.hash(this.type); + } + + @Override + public String toString() { + return "ResponseFormat{" + "type=" + this.type + '}'; + } + + public static final class Builder { + + private Type type; + + private Builder() { + } + + public Builder type(Type type) { + this.type = type; + return this; + } + + public ResponseFormat build() { + return new ResponseFormat(this.type); + } + + } + + public enum Type { + + /** + * Generates a text response. (default) + */ + @JsonProperty("text") + TEXT, + + /** + * Enables JSON mode, which guarantees the message the model generates is valid + * JSON. + */ + @JsonProperty("json_object") + JSON_OBJECT, + + } + +} diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/common/DeepSeekConstants.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/common/DeepSeekConstants.java new file mode 100644 index 00000000000..904b8e9a916 --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/common/DeepSeekConstants.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.deepseek.api.common; + +import org.springframework.ai.observation.conventions.AiProvider; + +/** + * @author Geng Rong + */ +public class DeepSeekConstants { + + public static final String DEFAULT_BASE_URL = "https://api.deepseek.com"; + + public static final String DEFAULT_COMPLETIONS_PATH = "/chat/completions"; + + public static final String DEFAULT_BETA_PATH = "/beta"; + + public static final String PROVIDER_NAME = AiProvider.DEEPSEEK.value(); + + private DeepSeekConstants() { + + } + +} diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/metadata/DeepSeekUsage.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/metadata/DeepSeekUsage.java new file mode 100644 index 00000000000..d6a3793b75e --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/metadata/DeepSeekUsage.java @@ -0,0 +1,62 @@ +/* + * Copyright 2023 - 2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.deepseek.metadata; + +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.util.Assert; + +/** + * @author Geng Rong + */ +public class DeepSeekUsage implements Usage { + + public static DeepSeekUsage from(DeepSeekApi.Usage usage) { + return new DeepSeekUsage(usage); + } + + private final DeepSeekApi.Usage usage; + + protected DeepSeekUsage(DeepSeekApi.Usage usage) { + Assert.notNull(usage, "DeepSeek Usage must not be null"); + this.usage = usage; + } + + protected DeepSeekApi.Usage getUsage() { + return this.usage; + } + + @Override + public Long getPromptTokens() { + return getUsage().promptTokens().longValue(); + } + + @Override + public Long getGenerationTokens() { + return getUsage().completionTokens().longValue(); + } + + @Override + public Long getTotalTokens() { + return getUsage().totalTokens().longValue(); + } + + @Override + public String toString() { + return getUsage().toString(); + } + +} diff --git a/models/spring-ai-deepseek/src/main/resources/META-INF/spring/aot.factories b/models/spring-ai-deepseek/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 00000000000..112c3a5eeb7 --- /dev/null +++ b/models/spring-ai-deepseek/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.aot.hint.RuntimeHintsRegistrar=\ + org.springframework.ai.deepseek.aot.DeepSeekRuntimeHints \ No newline at end of file diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekChatCompletionRequestTests.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekChatCompletionRequestTests.java new file mode 100644 index 00000000000..331df486bef --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekChatCompletionRequestTests.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.deepseek; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.deepseek.api.DeepSeekApi; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +public class DeepSeekChatCompletionRequestTests { + + @Test + public void createRequestWithChatOptions() { + + var client = new DeepSeekChatModel(new DeepSeekApi("TEST"), + DeepSeekChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6D).build()); + + var request = client.createRequest(new Prompt("Test message content"), false); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isFalse(); + + assertThat(request.model()).isEqualTo("DEFAULT_MODEL"); + assertThat(request.temperature()).isEqualTo(66.6D); + + request = client.createRequest(new Prompt("Test message content", + DeepSeekChatOptions.builder().model("PROMPT_MODEL").temperature(99.9D).build()), true); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isTrue(); + + assertThat(request.model()).isEqualTo("PROMPT_MODEL"); + assertThat(request.temperature()).isEqualTo(99.9D); + } + +} diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekRetryTests.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekRetryTests.java new file mode 100644 index 00000000000..2d89793f356 --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekRetryTests.java @@ -0,0 +1,148 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.deepseek; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.ai.deepseek.api.DeepSeekApi.*; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Role; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.retry.TransientAiException; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.RetryCallback; +import org.springframework.retry.RetryContext; +import org.springframework.retry.RetryListener; +import org.springframework.retry.support.RetryTemplate; + +import java.util.List; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.BDDMockito.given; + +/** + * @author Geng Rong + */ +@SuppressWarnings("unchecked") +@ExtendWith(MockitoExtension.class) +public class DeepSeekRetryTests { + + private TestRetryListener retryListener; + + private @Mock DeepSeekApi deepSeekApi; + + private DeepSeekChatModel chatModel; + + @BeforeEach + public void beforeEach() { + RetryTemplate retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + retryTemplate.registerListener(this.retryListener); + + this.chatModel = new DeepSeekChatModel(this.deepSeekApi, + DeepSeekChatOptions.builder() + .temperature(0.7) + .topP(1.0) + .model(ChatModel.DEEPSEEK_CHAT.getValue()) + .build(), + null, retryTemplate); + } + + @Test + public void deepSeekChatTransientError() { + + var choice = new ChatCompletion.Choice(ChatCompletionFinishReason.STOP, 0, + new ChatCompletionMessage("Response", Role.ASSISTANT), null); + ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 789L, "model", null, + "chat.completion", new DeepSeekApi.Usage(10, 10, 10)); + + given(this.deepSeekApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); + + var result = this.chatModel.call(new Prompt("text")); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getText()).isSameAs("Response"); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void deepSeekChatNonTransientError() { + given(this.deepSeekApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); + } + + @Test + public void deepSeekChatStreamTransientError() { + + var choice = new ChatCompletion.Choice(ChatCompletionFinishReason.STOP, 0, + new ChatCompletionMessage("Response", Role.ASSISTANT), null); + ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 666L, "model", null, + "chat.completion", new DeepSeekApi.Usage(10, 10, 10)); + + 
given(this.deepSeekApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); + + var result = this.chatModel.call(new Prompt("text")); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getText()).isSameAs("Response"); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void deepSeekChatStreamNonTransientError() { + given(this.deepSeekApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> this.chatModel.stream(new Prompt("text")).collectList().block()); + } + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + + } + +} diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekTestConfiguration.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekTestConfiguration.java new file mode 100644 index 00000000000..63d264edba7 --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekTestConfiguration.java @@ -0,0 +1,48 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.deepseek; + +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.util.StringUtils; + +/** + * @author Geng Rong + */ +@SpringBootConfiguration +public class DeepSeekTestConfiguration { + + @Bean + public DeepSeekApi deepSeekApi() { + return new DeepSeekApi(getApiKey()); + } + + private String getApiKey() { + String apiKey = System.getenv("DEEPSEEK_API_KEY"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "You must provide an API key. 
Put it in an environment variable under the name DEEPSEEK_API_KEY"); + } + return apiKey; + } + + @Bean + public DeepSeekChatModel deepSeekChatModel(DeepSeekApi api) { + return new DeepSeekChatModel(api); + } + +} diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/aot/DeepSeekRuntimeHintsTests.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/aot/DeepSeekRuntimeHintsTests.java new file mode 100644 index 00000000000..089db117125 --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/aot/DeepSeekRuntimeHintsTests.java @@ -0,0 +1,46 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.deepseek.aot; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.TypeReference; + +import java.util.Set; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; +import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; + +/** + * @author Geng Rong + */ +class DeepSeekRuntimeHintsTests { + + @Test + void registerHints() { + RuntimeHints runtimeHints = new RuntimeHints(); + DeepSeekRuntimeHints deepSeekRuntimeHints = new DeepSeekRuntimeHints(); + deepSeekRuntimeHints.registerHints(runtimeHints, null); + + Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(DeepSeekApi.class); + for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { + assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass)); + } + } + +} diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/api/DeepSeekApiIT.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/api/DeepSeekApiIT.java new file mode 100644 index 00000000000..3516ac22d10 --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/api/DeepSeekApiIT.java @@ -0,0 +1,57 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.deepseek.api; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.deepseek.api.DeepSeekApi.*; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Role; +import org.springframework.http.ResponseEntity; +import reactor.core.publisher.Flux; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +@EnabledIfEnvironmentVariable(named = "DEEPSEEK_API_KEY", matches = ".+") +public class DeepSeekApiIT { + + DeepSeekApi DeepSeekApi = new DeepSeekApi(System.getenv("DEEPSEEK_API_KEY")); + + @Test + void chatCompletionEntity() { + ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); + ResponseEntity response = DeepSeekApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(chatCompletionMessage), ChatModel.DEEPSEEK_CHAT.value, 1D, false)); + + assertThat(response).isNotNull(); + assertThat(response.getBody()).isNotNull(); + } + + @Test + void chatCompletionStream() { + ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); + Flux response = DeepSeekApi.chatCompletionStream( + new ChatCompletionRequest(List.of(chatCompletionMessage), ChatModel.DEEPSEEK_CHAT.value, 1D, true)); + + assertThat(response).isNotNull(); + assertThat(response.collectList().block()).isNotNull(); + } + +} diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/api/MockWeatherService.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/api/MockWeatherService.java new file mode 100644 index 00000000000..060c6594706 --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/api/MockWeatherService.java @@ -0,0 +1,95 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.deepseek.api; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +import java.util.function.Function; + +/** + * @author Geng Rong + */ +public class MockWeatherService implements Function { + + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, request.unit); + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. 
+ */ + F("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty("lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty("lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + + /** + * Weather Function response. + */ + public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, + Unit unit) { + + } + +} diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/ActorsFilms.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/ActorsFilms.java new file mode 100644 index 00000000000..53f529ef3e4 --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/ActorsFilms.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.deepseek.chat; + +import java.util.List; + +/** + * @author Geng Rong + */ +public class ActorsFilms { + + private String actor; + + private List movies; + + public ActorsFilms() { + } + + public String getActor() { + return actor; + } + + public void setActor(String actor) { + this.actor = actor; + } + + public List getMovies() { + return movies; + } + + public void setMovies(List movies) { + this.movies = movies; + } + + @Override + public String toString() { + return "ActorsFilms{" + "actor='" + actor + '\'' + ", movies=" + movies + '}'; + } + +} diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelFunctionCallingIT.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelFunctionCallingIT.java new file mode 100644 index 00000000000..d877370ef81 --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelFunctionCallingIT.java @@ -0,0 +1,185 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.deepseek.chat; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.deepseek.DeepSeekChatOptions; +import org.springframework.ai.deepseek.DeepSeekTestConfiguration; +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.ai.deepseek.api.MockWeatherService; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import reactor.core.publisher.Flux; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +@SpringBootTest(classes = DeepSeekTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "DEEPSEEK_API_KEY", matches = ".+") +class DeepSeekChatModelFunctionCallingIT { + + private static final Logger logger = LoggerFactory.getLogger(DeepSeekChatModelFunctionCallingIT.class); + + @Autowired + ChatModel chatModel; + + private static final DeepSeekApi.FunctionTool FUNCTION_TOOL = new DeepSeekApi.FunctionTool( + DeepSeekApi.FunctionTool.Type.FUNCTION, new DeepSeekApi.FunctionTool.Function( + "Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", """ + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "lat": { + "type": "number", + "description": "The city latitude" + }, + "lon": { + "type": "number", + "description": "The city longitude" + }, + "unit": { + "type": "string", + "enum": ["C", "F"] + } + }, + "required": ["location", "lat", "lon", "unit"] + } + """)); + + @Test + void functionCallTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = DeepSeekChatOptions.builder() + .model(DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue()) + .functionCallbacks(List.of(FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); + } + + @Test + void streamFunctionCallTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? 
Return the temperature in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = DeepSeekChatOptions.builder() + .functionCallbacks(List.of(FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .filter(Objects::nonNull) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).contains("30", "10", "15"); + } + + @Test + public void toolFunctionCallWithUsage() { + var promptOptions = DeepSeekChatOptions.builder() + .model(DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue()) + .tools(Arrays.asList(FUNCTION_TOOL)) + .functionCallbacks(List.of(FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location. Return temperature in 36°F or 36°C format.") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.", + promptOptions); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput()); + assertThat(chatResponse.getResult().getOutput().getText()).contains("San Francisco"); + assertThat(chatResponse.getResult().getOutput().getText()).contains("30.0"); + assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280); + } + + @Test + public void testStreamFunctionCallUsage() { + var promptOptions = DeepSeekChatOptions.builder() + .model(DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue()) + .tools(Arrays.asList(FUNCTION_TOOL)) + .functionCallbacks(List.of(FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location. Return temperature in 36°F or 36°C format.") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.", + promptOptions); + + ChatResponse chatResponse = this.chatModel.stream(prompt).blockLast(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getMetadata()).isNotNull(); + assertThat(chatResponse.getMetadata().getUsage()).isNotNull(); + assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280); + } + +} diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelIT.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelIT.java new file mode 100644 index 00000000000..d97ffa9a472 --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelIT.java @@ -0,0 +1,192 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.deepseek.chat; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.converter.BeanOutputConverter; +import org.springframework.ai.converter.ListOutputConverter; +import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.deepseek.DeepSeekTestConfiguration; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.core.io.Resource; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +@SpringBootTest(classes = DeepSeekTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "DEEPSEEK_API_KEY", matches = ".+") +class DeepSeekChatModelIT { + + @Autowired + protected ChatModel chatModel; + + @Autowired + protected StreamingChatModel streamingChatModel; + + private static final Logger logger = LoggerFactory.getLogger(DeepSeekChatModelIT.class); + + @Value("classpath:/prompts/system-message.st") + private Resource systemResource; + + @Test + void roleTest() { + UserMessage userMessage = new UserMessage( + "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + ChatResponse response = chatModel.call(prompt); + assertThat(response.getResults()).hasSize(1); + assertThat(response.getResults().get(0).getOutput().getText()).contains("Blackbeard"); + // needs fine tuning... 
evaluateQuestionAndAnswer(request, response, false); + } + + @Test + void listOutputConverter() { + DefaultConversionService conversionService = new DefaultConversionService(); + ListOutputConverter outputConverter = new ListOutputConverter(conversionService); + + String format = outputConverter.getFormat(); + String template = """ + List five {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "ice cream flavors", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.chatModel.call(prompt).getResult(); + + List list = outputConverter.convert(generation.getOutput().getText()); + assertThat(list).hasSize(5); + + } + + @Test + void mapOutputConverter() { + MapOutputConverter outputConverter = new MapOutputConverter(); + + String format = outputConverter.getFormat(); + String template = """ + Please provide the JSON response without any code block markers such as ```json```. + Provide me a List of {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + Map result = outputConverter.convert(generation.getOutput().getText()); + assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + + } + + @Test + void beanOutputConverter() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilms.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography for a random actor. + Please provide the JSON response without any code block markers such as ```json```. + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getText()); + } + + record ActorsFilmsRecord(String actor, List movies) { + } + + @Test + void beanOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + Please provide the JSON response without any code block markers such as ```json```. + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getText()); + logger.info("" + actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void beanStreamOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + Please provide the JSON response without any code block markers such as ```json```. 
+ {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + + String generationTextFromStream = streamingChatModel.stream(prompt) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .collect(Collectors.joining()); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream); + logger.info("" + actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + +} diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelObservationIT.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelObservationIT.java new file mode 100644 index 00000000000..69e4b908c0e --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelObservationIT.java @@ -0,0 +1,180 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.deepseek.chat; + +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.deepseek.DeepSeekChatModel; +import org.springframework.ai.deepseek.DeepSeekChatOptions; +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.retry.support.RetryTemplate; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; +import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.LowCardinalityKeyNames; + +/** + * Integration tests for observation instrumentation in {@link DeepSeekChatModel}. 
+ * + * @author Geng Rong + */ +@SpringBootTest(classes = DeepSeekChatModelObservationIT.Config.class) +@EnabledIfEnvironmentVariable(named = "DEEPSEEK_API_KEY", matches = ".+") +public class DeepSeekChatModelObservationIT { + + @Autowired + TestObservationRegistry observationRegistry; + + @Autowired + DeepSeekChatModel chatModel; + + @BeforeEach + void beforeEach() { + this.observationRegistry.clear(); + } + + @Test + void observationForChatOperation() { + var options = DeepSeekChatOptions.builder() + .model(DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue()) + .frequencyPenalty(0.0) + .maxTokens(2048) + .presencePenalty(0.0) + .stop(List.of("this-is-the-end")) + .temperature(0.7) + .topP(1.0) + .build(); + + Prompt prompt = new Prompt("Why does a raven look like a desk?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); + + ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + validate(responseMetadata); + } + + @Test + void observationForStreamingChatOperation() { + var options = DeepSeekChatOptions.builder() + .model(DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue()) + .frequencyPenalty(0.0) + .maxTokens(2048) + .presencePenalty(0.0) + .stop(List.of("this-is-the-end")) + .temperature(0.7) + .topP(1.0) + .build(); + + Prompt prompt = new Prompt("Why does a raven look like a desk?", options); + + Flux chatResponseFlux = this.chatModel.stream(prompt); + + List responses = chatResponseFlux.collectList().block(); + assertThat(responses).isNotEmpty(); + assertThat(responses).hasSizeGreaterThan(10); + + String aggregatedResponse = responses.subList(0, responses.size() - 1) + .stream() + .map(r -> r.getResult().getOutput().getText()) + .collect(Collectors.joining()); + assertThat(aggregatedResponse).isNotEmpty(); + + ChatResponse lastChatResponse = responses.get(responses.size() - 1); + + ChatResponseMetadata responseMetadata = lastChatResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + validate(responseMetadata); + } + + private void validate(ChatResponseMetadata responseMetadata) { + TestObservationRegistryAssert.assertThat(this.observationRegistry) + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) + .that() + .hasContextualNameEqualTo("chat " + DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), + AiOperationType.CHAT.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.DEEPSEEK.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), + DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), responseMetadata.getModel()) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), "0.0") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(), "0.0") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(), + "[\"this-is-the-end\"]") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7") + .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.REQUEST_TOP_K.asString()) + 
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TOP_P.asString(), "1.0") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.RESPONSE_ID.asString(), responseMetadata.getId()) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.RESPONSE_FINISH_REASONS.asString(), "[\"STOP\"]") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getPromptTokens())) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getGenerationTokens())) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getTotalTokens())) + .hasBeenStarted() + .hasBeenStopped(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public DeepSeekApi deepSeekApi() { + return new DeepSeekApi(System.getenv("DEEPSEEK_API_KEY")); + } + + @Bean + public DeepSeekChatModel deepSeekChatModel(DeepSeekApi deepSeekApi, + TestObservationRegistry observationRegistry) { + return new DeepSeekChatModel(deepSeekApi, DeepSeekChatOptions.builder().build(), + new DefaultFunctionCallbackResolver(), List.of(), RetryTemplate.defaultInstance(), + observationRegistry); + } + + } + +} diff --git a/models/spring-ai-deepseek/src/test/resources/prompts/system-message.st b/models/spring-ai-deepseek/src/test/resources/prompts/system-message.st new file mode 100644 index 00000000000..dc2cf2dcd84 --- /dev/null +++ b/models/spring-ai-deepseek/src/test/resources/prompts/system-message.st @@ -0,0 +1,4 @@ +"You are a helpful AI assistant. Your name is {name}. +You are an AI assistant that helps people find information. +Your name is {name} +You should reply to the user's request with your name and also in the style of a {voice}. 
\ No newline at end of file diff --git a/pom.xml b/pom.xml index 5d2159e1775..78ff4ab4da9 100644 --- a/pom.xml +++ b/pom.xml @@ -91,6 +91,7 @@ models/spring-ai-azure-openai models/spring-ai-bedrock models/spring-ai-bedrock-converse + models/spring-ai-deepseek models/spring-ai-huggingface models/spring-ai-minimax models/spring-ai-mistral-ai @@ -111,6 +112,7 @@ spring-ai-spring-boot-starters/spring-ai-starter-azure-openai spring-ai-spring-boot-starters/spring-ai-starter-bedrock-ai spring-ai-spring-boot-starters/spring-ai-starter-bedrock-converse + spring-ai-spring-boot-starters/spring-ai-starter-deepseek spring-ai-spring-boot-starters/spring-ai-starter-huggingface spring-ai-spring-boot-starters/spring-ai-starter-minimax spring-ai-spring-boot-starters/spring-ai-starter-mistral-ai diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index fb8fee83f4e..f665b2e9743 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -176,6 +176,12 @@ ${project.version} + + org.springframework.ai + spring-ai-deepseek + ${project.version} + + org.springframework.ai @@ -581,6 +587,12 @@ ${project.version} + + org.springframework.ai + spring-ai-deepseek-spring-boot-starter + ${project.version} + + diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java index e723b679b02..f66b024e561 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java @@ -86,6 +86,11 @@ public enum AiProvider { */ ZHIPUAI("zhipuai"), + /** + * AI system provided by DeepSeek. + */ + DEEPSEEK("deepseek"), + /** * AI system provided by Spring AI. */ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index d4234afe76e..88b58f6caef 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -22,6 +22,7 @@ **** xref:api/chat/vertexai-gemini-chat.adoc[VertexAI Gemini] ***** xref:api/chat/functions/vertexai-gemini-chat-functions.adoc[Gemini Function Calling] *** xref:api/chat/groq-chat.adoc[Groq] +*** xref:api/chat/deepseek-chat.adoc[DeepSeek] *** xref:api/chat/huggingface.adoc[Hugging Face] *** xref:api/chat/mistralai-chat.adoc[Mistral AI] **** xref:api/chat/functions/mistralai-chat-functions.adoc[Mistral Function Calling] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/deepseek-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/deepseek-chat.adoc new file mode 100644 index 00000000000..54f433c45e5 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/deepseek-chat.adoc @@ -0,0 +1,251 @@ += DeepSeek Chat + +Spring AI supports the various AI language models from DeepSeek. You can interact with DeepSeek language models and create a multilingual conversational assistant based on DeepSeek models. + +== Prerequisites + +You will need to create an API with DeepSeek to access DeepSeek language models. +Create an account at https://platform.deepseek.com/sign_up[DeepSeek registration page] and generate the token on the https://platform.deepseek.com/api_keys[API Keys page]. 
+The Spring AI project defines a configuration property named `spring.ai.deepseek.api-key` that you should set to the value of the `API Key` obtained from the https://platform.deepseek.com/api_keys[API Keys page].
+Exporting an environment variable is one way to set that configuration property:
+
+[source,shell]
+----
+export SPRING_AI_DEEPSEEK_API_KEY=<INSERT KEY HERE>
+----
+
+=== Add Repositories and BOM
+
+Spring AI artifacts are published in Spring Milestone and Snapshot repositories.
+Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system.
+
+To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system.
+
+
+
+== Auto-configuration
+
+Spring AI provides Spring Boot auto-configuration for the DeepSeek Chat Model.
+To enable it, add the following dependency to your project's Maven `pom.xml` file:
+
+[source, xml]
+----
+<dependency>
+    <groupId>org.springframework.ai</groupId>
+    <artifactId>spring-ai-deepseek-spring-boot-starter</artifactId>
+</dependency>
+----
+
+or to your Gradle `build.gradle` build file.
+
+[source,groovy]
+----
+dependencies {
+    implementation 'org.springframework.ai:spring-ai-deepseek-spring-boot-starter'
+}
+----
+
+TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
+
+=== Chat Properties
+
+==== Retry Properties
+
+The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the DeepSeek Chat model.
+
+[cols="3,5,1"]
+|====
+| Property | Description | Default
+
+| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10
+| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec.
+| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5
+| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min.
+| spring.ai.retry.on-client-errors | If false, throw a NonTransientAiException and do not attempt retries for `4xx` client error codes. | false
+| spring.ai.retry.exclude-on-http-codes | List of HTTP status codes that should not trigger a retry (e.g. to throw NonTransientAiException). | empty
+| spring.ai.retry.on-http-codes | List of HTTP status codes that should trigger a retry (e.g. to throw TransientAiException). | empty
+|====
+
+==== Connection Properties
+
+The prefix `spring.ai.deepseek` is used as the property prefix that lets you connect to DeepSeek.
+
+[cols="3,5,1"]
+|====
+| Property | Description | Default
+
+| spring.ai.deepseek.base-url | The URL to connect to | https://api.deepseek.com
+| spring.ai.deepseek.api-key | The API Key | -
+|====
+
+==== Configuration Properties
+
+The prefix `spring.ai.deepseek.chat` is the property prefix that lets you configure the chat model implementation for DeepSeek.
+
+[cols="3,5,1"]
+|====
+| Property | Description | Default
+
+| spring.ai.deepseek.chat.enabled | Enable DeepSeek chat model. | true
+| spring.ai.deepseek.chat.base-url | Optionally overrides the spring.ai.deepseek.base-url to provide a chat-specific URL | https://api.deepseek.com/
+| spring.ai.deepseek.chat.api-key | Optionally overrides the spring.ai.deepseek.api-key to provide a chat-specific api-key | -
+| spring.ai.deepseek.chat.completions-path | The path to the chat completions endpoint | /chat/completions
+| spring.ai.deepseek.chat.beta-feature-path | The path to the beta feature endpoint | /beta
+| spring.ai.deepseek.chat.options.model | ID of the model to use. You can use either deepseek-coder or deepseek-chat. | deepseek-chat
+| spring.ai.deepseek.chat.options.frequencyPenalty | Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. | 0.0
+| spring.ai.deepseek.chat.options.maxTokens | The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length. | -
+| spring.ai.deepseek.chat.options.presencePenalty | Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. | 0.0
+| spring.ai.deepseek.chat.options.stop | Up to 4 sequences where the API will stop generating further tokens. | -
+| spring.ai.deepseek.chat.options.temperature | What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both. | 1.0
+| spring.ai.deepseek.chat.options.topP | An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. | 1.0
+| spring.ai.deepseek.chat.options.logprobs | Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token in the content of `message`. | -
+| spring.ai.deepseek.chat.options.topLogprobs | An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to true if this parameter is used. | -
+|====
+
+NOTE: You can override the common `spring.ai.deepseek.base-url` and `spring.ai.deepseek.api-key` for the `ChatModel` implementations.
+The `spring.ai.deepseek.chat.base-url` and `spring.ai.deepseek.chat.api-key` properties, if set, take precedence over the common properties.
+This is useful if you want to use different DeepSeek accounts for different models and different model endpoints.
+
+TIP: All properties prefixed with `spring.ai.deepseek.chat.options` can be overridden at runtime by adding a request-specific <<chat-options>> to the `Prompt` call.
+
+== Runtime Options [[chat-options]]
+
+The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java[DeepSeekChatOptions.java] provides model configurations, such as the model to use, the temperature, the frequency penalty, etc.
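+
+For example, a typical options object is assembled with the fluent builder. The following is a minimal sketch; the values are purely illustrative, not recommended defaults:
+
+[source,java]
+----
+// Illustrative values only; choose settings that suit your use case.
+DeepSeekChatOptions options = DeepSeekChatOptions.builder()
+    .model(DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue())
+    .temperature(0.7)
+    .maxTokens(1024)
+    .build();
+----
+
+The same options type is used both for start-up defaults and for per-request overrides, as described below.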
+
+On start-up, the default options can be configured with the `DeepSeekChatModel(api, options)` constructor or the `spring.ai.deepseek.chat.options.*` properties.
+
+At run-time, you can override the default options by adding new, request-specific options to the `Prompt` call.
+For example, to override the default model and temperature for a specific request:
+
+[source,java]
+----
+ChatResponse response = chatModel.call(
+    new Prompt(
+        "Generate the names of 5 famous pirates. Please provide the JSON response without any code block markers such as ```json```.",
+        DeepSeekChatOptions.builder()
+            .model(DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue())
+            .temperature(0.8)
+            .build()
+    ));
+----
+
+TIP: In addition to the model-specific link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java[DeepSeekChatOptions], you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()].
+
+== Sample Controller (Auto-configuration)
+
+https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-deepseek-spring-boot-starter` to your pom (or gradle) dependencies.
+
+Add an `application.properties` file under the `src/main/resources` directory to enable and configure the DeepSeek Chat model:
+
+[source,application.properties]
+----
+spring.ai.deepseek.api-key=YOUR_API_KEY
+spring.ai.deepseek.chat.options.model=deepseek-chat
+spring.ai.deepseek.chat.options.temperature=0.8
+----
+
+TIP: Replace the `api-key` value with your DeepSeek API key.
+
+This will create a `DeepSeekChatModel` implementation that you can inject into your classes.
+Here is an example of a simple `@RestController` class that uses the chat model for text generation.
+
+[source,java]
+----
+@RestController
+public class ChatController {
+
+    private final DeepSeekChatModel chatModel;
+
+    @Autowired
+    public ChatController(DeepSeekChatModel chatModel) {
+        this.chatModel = chatModel;
+    }
+
+    @GetMapping("/ai/generate")
+    public Map<String, String> generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
+        return Map.of("generation", chatModel.call(message));
+    }
+
+    @GetMapping("/ai/generateStream")
+    public Flux<ChatResponse> generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
+        var prompt = new Prompt(new UserMessage(message));
+        return chatModel.stream(prompt);
+    }
+}
+----
+
+== Manual Configuration
+
+The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java[DeepSeekChatModel] implements the `ChatModel` and `StreamingChatModel` and uses the <<low-level-api>> to connect to the DeepSeek service.
+
+Add the `spring-ai-deepseek` dependency to your project's Maven `pom.xml` file:
+
+[source, xml]
+----
+<dependency>
+    <groupId>org.springframework.ai</groupId>
+    <artifactId>spring-ai-deepseek</artifactId>
+</dependency>
+----
+
+or to your Gradle `build.gradle` build file.
+
+[source,groovy]
+----
+dependencies {
+    implementation 'org.springframework.ai:spring-ai-deepseek'
+}
+----
+
+TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
+
+Next, create a `DeepSeekChatModel` and use it for text generation:
+
+[source,java]
+----
+var deepSeekApi = new DeepSeekApi(System.getenv("DEEPSEEK_API_KEY"));
+
+var chatModel = new DeepSeekChatModel(deepSeekApi, DeepSeekChatOptions.builder()
+    .model(DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue())
+    .temperature(0.4)
+    .maxTokens(200)
+    .build());
+
+ChatResponse response = chatModel.call(
+    new Prompt("Generate the names of 5 famous pirates."));
+
+// Or with streaming responses
+Flux<ChatResponse> streamResponse = chatModel.stream(
+    new Prompt("Generate the names of 5 famous pirates."));
+----
+
+The `DeepSeekChatOptions` provides the configuration information for the chat requests.
+The `DeepSeekChatOptions.Builder` is a fluent options builder.
+
+=== Low-level DeepSeekApi Client [[low-level-api]]
+
+The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java[DeepSeekApi] is a lightweight Java client for the link:https://platform.deepseek.com/api-docs/[DeepSeek API].
+
+Here is a simple snippet showing how to use the API programmatically:
+
+[source,java]
+----
+DeepSeekApi deepSeekApi =
+    new DeepSeekApi(System.getenv("DEEPSEEK_API_KEY"));
+
+ChatCompletionMessage chatCompletionMessage =
+    new ChatCompletionMessage("Hello world", Role.USER);
+
+// Sync request
+ResponseEntity<ChatCompletion> response = deepSeekApi.chatCompletionEntity(
+    new ChatCompletionRequest(List.of(chatCompletionMessage), DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue(), 0.7, false));
+
+// Streaming request
+Flux<ChatCompletionChunk> streamResponse = deepSeekApi.chatCompletionStream(
+    new ChatCompletionRequest(List.of(chatCompletionMessage), DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue(), 0.7, true));
+----
+
+Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java[DeepSeekApi.java]'s JavaDoc for further information.
+
+==== DeepSeekApi Samples
+* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/api/DeepSeekApiIT.java[DeepSeekApiIT.java] test provides some general examples of how to use the lightweight library.
diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 92ac01362c6..94a25b90751 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -366,6 +366,14 @@ true + + + org.springframework.ai + spring-ai-deepseek + ${project.parent.version} + true + + org.springframework.ai diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekAutoConfiguration.java new file mode 100644 index 00000000000..1b5157b9b71 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekAutoConfiguration.java @@ -0,0 +1,103 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.deepseek; + +import io.micrometer.observation.ObservationRegistry; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.chat.observation.ChatModelObservationConvention; +import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.deepseek.DeepSeekChatModel; +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; + +import java.util.List; + +/** + * {@link AutoConfiguration Auto-configuration} for DeepSeek Chat Model. 
+ * + * @author Geng Rong + */ +@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class }) +@EnableConfigurationProperties({ DeepSeekCommonProperties.class, DeepSeekChatProperties.class }) +@ConditionalOnClass(DeepSeekApi.class) +public class DeepSeekAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = DeepSeekChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", + matchIfMissing = true) + public DeepSeekChatModel deepSeekChatModel(DeepSeekCommonProperties commonProperties, + DeepSeekChatProperties chatProperties, ObjectProvider restClientBuilderProvider, + ObjectProvider webClientBuilderProvider, List toolFunctionCallbacks, + FunctionCallbackResolver functionCallbackResolver, RetryTemplate retryTemplate, + ResponseErrorHandler responseErrorHandler, ObjectProvider observationRegistry, + ObjectProvider observationConvention) { + + var deepSeekApi = deepSeekApi(chatProperties, commonProperties, + restClientBuilderProvider.getIfAvailable(RestClient::builder), + webClientBuilderProvider.getIfAvailable(WebClient::builder), responseErrorHandler); + + var chatModel = new DeepSeekChatModel(deepSeekApi, chatProperties.getOptions(), functionCallbackResolver, + toolFunctionCallbacks, retryTemplate, observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)); + + observationConvention.ifAvailable(chatModel::setObservationConvention); + return chatModel; + } + + @Bean + @ConditionalOnMissingBean + public FunctionCallbackResolver springAiFunctionManager(ApplicationContext context) { + DefaultFunctionCallbackResolver manager = new DefaultFunctionCallbackResolver(); + manager.setApplicationContext(context); + return manager; + } + + private DeepSeekApi deepSeekApi(DeepSeekChatProperties chatProperties, DeepSeekCommonProperties commonProperties, + RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, + ResponseErrorHandler responseErrorHandler) { + + var resolvedApiKey = getTextOrElse(chatProperties.getApiKey(), commonProperties.getApiKey()); + var resoledBaseUrl = getTextOrElse(chatProperties.getBaseUrl(), commonProperties.getBaseUrl()); + + Assert.hasText(resolvedApiKey, "DeepSeek API key must be set"); + Assert.hasText(resoledBaseUrl, "DeepSeek base URL must be set"); + + return new DeepSeekApi(resoledBaseUrl, resolvedApiKey, chatProperties.getCompletionsPath(), + chatProperties.getBetaFeaturePath(), restClientBuilder, webClientBuilder, responseErrorHandler); + } + + private String getTextOrElse(String text, String defaultValue) { + return StringUtils.hasText(text) ? text : defaultValue; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekChatProperties.java new file mode 100644 index 00000000000..d33125d91e4 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekChatProperties.java @@ -0,0 +1,89 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.deepseek; + +import org.springframework.ai.deepseek.DeepSeekChatOptions; +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * Configuration properties for DeepSeek chat client. + * + * @author Geng Rong + */ +@ConfigurationProperties(DeepSeekChatProperties.CONFIG_PREFIX) +public class DeepSeekChatProperties extends DeepSeekParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.deepseek.chat"; + + public static final String DEFAULT_CHAT_MODEL = DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue(); + + private static final Double DEFAULT_TEMPERATURE = 1D; + + public static final String DEFAULT_COMPLETIONS_PATH = "/chat/completions"; + + public static final String DEFAULT_BETA_PATH = "/beta"; + + /** + * Enable DeepSeek chat client. + */ + private boolean enabled = true; + + private String completionsPath = DEFAULT_COMPLETIONS_PATH; + + private String betaFeaturePath = DEFAULT_BETA_PATH; + + @NestedConfigurationProperty + private DeepSeekChatOptions options = DeepSeekChatOptions.builder() + .model(DEFAULT_CHAT_MODEL) + .temperature(DEFAULT_TEMPERATURE) + .build(); + + public DeepSeekChatOptions getOptions() { + return this.options; + } + + public void setOptions(DeepSeekChatOptions options) { + this.options = options; + } + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + public String getCompletionsPath() { + return completionsPath; + } + + public void setCompletionsPath(String completionsPath) { + this.completionsPath = completionsPath; + } + + public String getBetaFeaturePath() { + return betaFeaturePath; + } + + public void setBetaFeaturePath(String betaFeaturePath) { + this.betaFeaturePath = betaFeaturePath; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekCommonProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekCommonProperties.java new file mode 100644 index 00000000000..e7f73305b2e --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekCommonProperties.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.autoconfigure.deepseek; + +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * Parent properties for DeepSeek. + * + * @author Geng Rong + */ +@ConfigurationProperties(DeepSeekCommonProperties.CONFIG_PREFIX) +public class DeepSeekCommonProperties extends DeepSeekParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.deepseek"; + + public static final String DEFAULT_BASE_URL = "https://api.deepseek.com"; + + public DeepSeekCommonProperties() { + super.setBaseUrl(DEFAULT_BASE_URL); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekParentProperties.java new file mode 100644 index 00000000000..ba908716db4 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekParentProperties.java @@ -0,0 +1,46 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.deepseek; + +/** + * Parent properties for DeepSeek. 
+ * + * @author Geng Rong + */ +public class DeepSeekParentProperties { + + private String apiKey; + + private String baseUrl; + + public String getApiKey() { + return this.apiKey; + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public String getBaseUrl() { + return this.baseUrl; + } + + public void setBaseUrl(String baseUrl) { + this.baseUrl = baseUrl; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index f3e5633efc0..3829957e93f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -60,6 +60,7 @@ org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration org.springframework.ai.autoconfigure.chat.client.ChatClientAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.typesense.TypesenseVectorStoreAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.opensearch.OpenSearchVectorStoreAutoConfiguration +org.springframework.ai.autoconfigure.deepseek.DeepSeekAutoConfiguration org.springframework.ai.autoconfigure.moonshot.MoonshotAutoConfiguration org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration org.springframework.ai.autoconfigure.minimax.MiniMaxAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekAutoConfigurationIT.java new file mode 100644 index 00000000000..286695d0c77 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekAutoConfigurationIT.java @@ -0,0 +1,76 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.autoconfigure.deepseek; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.deepseek.DeepSeekChatModel; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import reactor.core.publisher.Flux; + +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +@EnabledIfEnvironmentVariable(named = "DEEPSEEK_API_KEY", matches = ".*") +public class DeepSeekAutoConfigurationIT { + + private static final Log logger = LogFactory.getLog(DeepSeekAutoConfigurationIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.deepseek.apiKey=" + System.getenv("DEEPSEEK_API_KEY")) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, DeepSeekAutoConfiguration.class)); + + @Test + void generate() { + this.contextRunner.run(context -> { + DeepSeekChatModel client = context.getBean(DeepSeekChatModel.class); + String response = client.call("Hello"); + assertThat(response).isNotEmpty(); + logger.info("Response: " + response); + }); + } + + @Test + void generateStreaming() { + this.contextRunner.run(context -> { + DeepSeekChatModel client = context.getBean(DeepSeekChatModel.class); + Flux responseFlux = client.stream(new Prompt(new UserMessage("Hello"))); + String response = Objects.requireNonNull(responseFlux.collectList().block()) + .stream() + .map(chatResponse -> chatResponse.getResults().get(0).getOutput().getText()) + .collect(Collectors.joining()); + + assertThat(response).isNotEmpty(); + logger.info("Response: " + response); + }); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekPropertiesTests.java new file mode 100644 index 00000000000..6b3188c1ede --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekPropertiesTests.java @@ -0,0 +1,162 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.autoconfigure.deepseek; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.deepseek.DeepSeekChatModel; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +public class DeepSeekPropertiesTests { + + @Test + public void chatProperties() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.deepseek.base-url=TEST_BASE_URL", + "spring.ai.deepseek.api-key=abc123", + "spring.ai.deepseek.chat.options.model=MODEL_XYZ", + "spring.ai.deepseek.chat.options.temperature=0.55") + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, DeepSeekAutoConfiguration.class)) + .run(context -> { + var chatProperties = context.getBean(DeepSeekChatProperties.class); + var connectionProperties = context.getBean(DeepSeekCommonProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(chatProperties.getApiKey()).isNull(); + assertThat(chatProperties.getBaseUrl()).isNull(); + + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55); + }); + } + + @Test + public void chatOverrideConnectionProperties() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.deepseek.base-url=TEST_BASE_URL", + "spring.ai.deepseek.api-key=abc123", + "spring.ai.deepseek.chat.base-url=TEST_BASE_URL2", + "spring.ai.deepseek.chat.api-key=456", + "spring.ai.deepseek.chat.options.model=MODEL_XYZ", + "spring.ai.deepseek.chat.options.temperature=0.55") + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, DeepSeekAutoConfiguration.class)) + .run(context -> { + var chatProperties = context.getBean(DeepSeekChatProperties.class); + var connectionProperties = context.getBean(DeepSeekCommonProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(chatProperties.getApiKey()).isEqualTo("456"); + assertThat(chatProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); + + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55); + }); + } + + @Test + public void chatOptionsTest() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.deepseek.api-key=API_KEY", + "spring.ai.deepseek.base-url=TEST_BASE_URL", + + "spring.ai.deepseek.chat.options.model=MODEL_XYZ", + "spring.ai.deepseek.chat.options.frequencyPenalty=-1.5", + "spring.ai.deepseek.chat.options.logitBias.myTokenId=-5", + "spring.ai.deepseek.chat.options.maxTokens=123", + "spring.ai.deepseek.chat.options.n=10", + "spring.ai.deepseek.chat.options.presencePenalty=0", + "spring.ai.deepseek.chat.options.responseFormat.type=json", + "spring.ai.deepseek.chat.options.seed=66", + 
"spring.ai.deepseek.chat.options.stop=boza,koza", + "spring.ai.deepseek.chat.options.temperature=0.55", + "spring.ai.deepseek.chat.options.topP=0.56", + "spring.ai.deepseek.chat.options.user=userXYZ" + ) + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, DeepSeekAutoConfiguration.class)) + .run(context -> { + var chatProperties = context.getBean(DeepSeekChatProperties.class); + var connectionProperties = context.getBean(DeepSeekCommonProperties.class); + + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getFrequencyPenalty()).isEqualTo(-1.5); + assertThat(chatProperties.getOptions().getMaxTokens()).isEqualTo(123); + assertThat(chatProperties.getOptions().getPresencePenalty()).isEqualTo(0); + assertThat(chatProperties.getOptions().getStop()).contains("boza", "koza"); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55); + assertThat(chatProperties.getOptions().getTopP()).isEqualTo(0.56); + }); + } + + @Test + void chatActivation() { + new ApplicationContextRunner() + .withPropertyValues("spring.ai.deepseek.api-key=API_KEY", "spring.ai.deepseek.base-url=TEST_BASE_URL", + "spring.ai.deepseek.chat.enabled=false") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, DeepSeekAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(DeepSeekChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(DeepSeekChatModel.class)).isEmpty(); + }); + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.deepseek.api-key=API_KEY", "spring.ai.deepseek.base-url=TEST_BASE_URL") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, DeepSeekAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(DeepSeekChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(DeepSeekChatModel.class)).isNotEmpty(); + }); + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.deepseek.api-key=API_KEY", "spring.ai.deepseek.base-url=TEST_BASE_URL", + "spring.ai.deepseek.chat.enabled=true") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, DeepSeekAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(DeepSeekChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(DeepSeekChatModel.class)).isNotEmpty(); + }); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/tool/DeepSeekFunctionCallbackIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/tool/DeepSeekFunctionCallbackIT.java new file mode 100644 index 00000000000..6a527817440 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/tool/DeepSeekFunctionCallbackIT.java @@ -0,0 +1,124 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.deepseek.tool; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.autoconfigure.deepseek.DeepSeekAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.deepseek.DeepSeekChatModel; +import org.springframework.ai.deepseek.DeepSeekChatOptions; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +@EnabledIfEnvironmentVariable(named = "DEEPSEEK_API_KEY", matches = ".*") +public class DeepSeekFunctionCallbackIT { + + private final Logger logger = LoggerFactory.getLogger(DeepSeekFunctionCallbackIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.deepseek.apiKey=" + System.getenv("DEEPSEEK_API_KEY")) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, DeepSeekAutoConfiguration.class)) + .withUserConfiguration(Config.class); + + @Test + void functionCallTest() { + this.contextRunner.run(context -> { + + DeepSeekChatModel chatModel = context.getBean(DeepSeekChatModel.class); + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius"); + + ChatResponse response = chatModel + .call(new Prompt(List.of(userMessage), DeepSeekChatOptions.builder().function("WeatherInfo").build())); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); + + }); + } + + @Test + void streamFunctionCallTest() { + this.contextRunner.run(context -> { + + DeepSeekChatModel chatModel = context.getBean(DeepSeekChatModel.class); + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? 
Return the temperature in Celsius"); + + Flux<ChatResponse> response = chatModel.stream( + new Prompt(List.of(userMessage), DeepSeekChatOptions.builder().function("WeatherInfo").build())); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .filter(Objects::nonNull) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).containsAnyOf("30.0", "30"); + assertThat(content).containsAnyOf("10.0", "10"); + assertThat(content).containsAnyOf("15.0", "15"); + + }); + } + + @Configuration + static class Config { + + @Bean + public FunctionCallback weatherFunctionInfo() { + + return FunctionCallback.builder() + .function("WeatherInfo", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build(); + } + + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/tool/FunctionCallbackInPromptIT.java new file mode 100644 index 00000000000..5fb57abace5 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/tool/FunctionCallbackInPromptIT.java @@ -0,0 +1,117 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.springframework.ai.autoconfigure.deepseek.tool; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.autoconfigure.deepseek.DeepSeekAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.deepseek.DeepSeekChatModel; +import org.springframework.ai.deepseek.DeepSeekChatOptions; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +@EnabledIfEnvironmentVariable(named = "DEEPSEEK_API_KEY", matches = ".*") +public class FunctionCallbackInPromptIT { + + private final Logger logger = LoggerFactory.getLogger(FunctionCallbackInPromptIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.deepseek.apiKey=" + System.getenv("DEEPSEEK_API_KEY")) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, DeepSeekAutoConfiguration.class)); + + @Test + void functionCallTest() { + this.contextRunner.run(context -> { + + DeepSeekChatModel chatModel = context.getBean(DeepSeekChatModel.class); + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius"); + + var promptOptions = DeepSeekChatOptions.builder() + .functionCallbacks(List.of(FunctionCallback.builder() + .function("CurrentWeatherService", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); + }); + } + + @Test + void streamingFunctionCallTest() { + + this.contextRunner.run(context -> { + + DeepSeekChatModel chatModel = context.getBean(DeepSeekChatModel.class); + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? 
Return the temperature in Celsius"); + + var promptOptions = DeepSeekChatOptions.builder() + .functionCallbacks(List.of(FunctionCallback.builder() + .function("CurrentWeatherService", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + Flux<ChatResponse> response = chatModel.stream(new Prompt(List.of(userMessage), promptOptions)); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).containsAnyOf("30.0", "30"); + assertThat(content).containsAnyOf("10.0", "10"); + assertThat(content).containsAnyOf("15.0", "15"); + }); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/tool/FunctionCallbackWithPlainFunctionBeanIT.java new file mode 100644 index 00000000000..0f939a7aea6 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -0,0 +1,175 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.springframework.ai.autoconfigure.deepseek.tool; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.autoconfigure.deepseek.DeepSeekAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.deepseek.DeepSeekChatModel; +import org.springframework.ai.deepseek.DeepSeekChatOptions; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Description; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +@EnabledIfEnvironmentVariable(named = "DEEPSEEK_API_KEY", matches = ".*") +class FunctionCallbackWithPlainFunctionBeanIT { + + private final Logger logger = LoggerFactory.getLogger(FunctionCallbackWithPlainFunctionBeanIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.deepseek.apiKey=" + System.getenv("DEEPSEEK_API_KEY")) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, DeepSeekAutoConfiguration.class)) + .withUserConfiguration(Config.class); + + @Test + void functionCallTest() { + this.contextRunner.run(context -> { + + DeepSeekChatModel chatModel = context.getBean(DeepSeekChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius"); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + DeepSeekChatOptions.builder().function("weatherFunction").build())); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); + + // Test weatherFunctionTwo + response = chatModel.call(new Prompt(List.of(userMessage), + DeepSeekChatOptions.builder().function("weatherFunctionTwo").build())); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); + + }); + } + + @Test + void functionCallWithPortableFunctionCallingOptions() { + this.contextRunner.run(context -> { + + DeepSeekChatModel chatModel = context.getBean(DeepSeekChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? 
Return the temperature in Celsius"); + + FunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + .function("weatherFunction") + .build(); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); + + logger.info("Response: {}", response); + }); + } + + @Test + void streamFunctionCallTest() { + this.contextRunner.run(context -> { + + DeepSeekChatModel chatModel = context.getBean(DeepSeekChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius"); + + Flux<ChatResponse> response = chatModel.stream(new Prompt(List.of(userMessage), + DeepSeekChatOptions.builder().function("weatherFunction").build())); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).containsAnyOf("30.0", "30"); + assertThat(content).containsAnyOf("10.0", "10"); + assertThat(content).containsAnyOf("15.0", "15"); + + // Test weatherFunctionTwo + response = chatModel.stream(new Prompt(List.of(userMessage), + DeepSeekChatOptions.builder().function("weatherFunctionTwo").build())); + + content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).containsAnyOf("30.0", "30"); + assertThat(content).containsAnyOf("10.0", "10"); + assertThat(content).containsAnyOf("15.0", "15"); + }); + } + + @Configuration + static class Config { + + @Bean + @Description("Get the weather in location") + public Function<MockWeatherService.Request, MockWeatherService.Response> weatherFunction() { + return new MockWeatherService(); + } + + // Relies on the Request's JsonClassDescription annotation to provide the + // function description. + @Bean + public Function<MockWeatherService.Request, MockWeatherService.Response> weatherFunctionTwo() { + MockWeatherService weatherService = new MockWeatherService(); + return (weatherService::apply); + } + + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/tool/MockWeatherService.java new file mode 100644 index 00000000000..c186eb31602 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/tool/MockWeatherService.java @@ -0,0 +1,95 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.springframework.ai.autoconfigure.deepseek.tool; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +import java.util.function.Function; + +/** + * Mock 3rd party weather service. + * + * @author Geng Rong + */ +public class MockWeatherService implements Function<MockWeatherService.Request, MockWeatherService.Response> { + + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. + */ + F("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + + /** + * Weather Function response. + */ + public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, + Unit unit) { + + } + +} diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-deepseek/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-deepseek/pom.xml new file mode 100644 index 00000000000..4bd4986fbfc --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-deepseek/pom.xml @@ -0,0 +1,58 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <groupId>org.springframework.ai</groupId> + <artifactId>spring-ai</artifactId> + <version>1.0.0-SNAPSHOT</version> + <relativePath>../../pom.xml</relativePath> + </parent> + <artifactId>spring-ai-deepseek-spring-boot-starter</artifactId> + <packaging>jar</packaging> + <name>Spring AI Starter - DeepSeek</name> + <description>Spring AI DeepSeek Auto Configuration</description> + <url>https://github.com/spring-projects/spring-ai</url> + + <scm> + <url>https://github.com/spring-projects/spring-ai</url> + <connection>git://github.com/spring-projects/spring-ai.git</connection> + <developerConnection>git@github.com:spring-projects/spring-ai.git</developerConnection> + </scm> + + <dependencies> + <dependency> + <groupId>org.springframework.boot</groupId> + <artifactId>spring-boot-starter</artifactId> + </dependency> + + <dependency> + <groupId>org.springframework.ai</groupId> + <artifactId>spring-ai-spring-boot-autoconfigure</artifactId> + <version>${project.parent.version}</version> + </dependency> + + <dependency> + <groupId>org.springframework.ai</groupId> + <artifactId>spring-ai-deepseek</artifactId> + <version>${project.parent.version}</version> + </dependency> + </dependencies> + +</project>
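
Usage note (illustrative sketch, not part of the patch): with the spring-ai-deepseek-spring-boot-starter above on the classpath and spring.ai.deepseek.api-key set (for example from the DEEPSEEK_API_KEY environment variable, as in the integration tests), DeepSeekAutoConfiguration exposes a DeepSeekChatModel bean that can be injected and called directly. The application class and package below are hypothetical; only DeepSeekChatModel and its call(String) method come from this change set.

    package com.example.demo; // hypothetical demo application

    import org.springframework.ai.deepseek.DeepSeekChatModel;
    import org.springframework.boot.CommandLineRunner;
    import org.springframework.boot.SpringApplication;
    import org.springframework.boot.autoconfigure.SpringBootApplication;
    import org.springframework.context.annotation.Bean;

    @SpringBootApplication
    public class DeepSeekDemoApplication {

        public static void main(String[] args) {
            SpringApplication.run(DeepSeekDemoApplication.class, args);
        }

        // DeepSeekChatModel is auto-configured by DeepSeekAutoConfiguration
        // once the API key property is present.
        @Bean
        CommandLineRunner demo(DeepSeekChatModel chatModel) {
            return args -> System.out.println(chatModel.call("Hello"));
        }

    }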