Skip to content

Commit

Permalink
Advancing Tool Support - Part 3
Browse files Browse the repository at this point in the history
* Introduced ToolCallingManager to manage the tool calling activities for resolving and executing tools. A default implementation is provided. It can be used to handle explicit tool execution on the client-side, superseding the previous FunctionCallingHelper class. It’s ready to be instrumented via Micrometer, and support exception handling when tool calls fail.
* Introduced ToolCallExceptionConverter to handle exceptions in tool calling, and provided a default implementation propagating the error message to the chat morel.
* Introduced ToolCallbackResolver to resolve ToolCallback instances. A default implementation is provided (DelegatingToolCallbackResolver), capable of delegating the resolution to a series of resolvers, including static resolution (StaticToolCallbackResolver) and dynamic resolution from the Spring context (SpringBeanToolCallbackResolver).
* Improved configuration in ToolCallingChatOptions to enable/disable the tool execution within a ChatModel (superseding the previous proxyToolCalls option).
* Added unit and integration tests to cover all the new use cases and existing functionality which was not covered by autotests (tool resolution from Spring context).
* Deprecated FunctionCallbackResolver, AbstractToolCallSupport, and FunctionCallingHelper.

Relates to gh-2049

Signed-off-by: Thomas Vitale <[email protected]>
  • Loading branch information
ThomasVitale committed Jan 27, 2025
1 parent 2f14597 commit f678969
Show file tree
Hide file tree
Showing 46 changed files with 2,123 additions and 289 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

Expand All @@ -44,7 +45,9 @@
* @author Thomas Vitale
* @author Jihoon Kim
* @since 1.0.0
* @deprecated Use {@link ToolCallingManager} instead.
*/
@Deprecated
public abstract class AbstractToolCallSupport {

protected static final boolean IS_RUNTIME_CALL = true;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,6 +33,7 @@
* @author Soby Chacko
* @author John Blum
* @author Alexandros Pappas
* @author Thomas Vitale
*/
public class ChatResponse implements ModelResponse<Generation> {

Expand Down Expand Up @@ -100,6 +101,16 @@ public ChatResponseMetadata getMetadata() {
return this.chatResponseMetadata;
}

/**
* Whether the model has requested the execution of a tool.
*/
public boolean hasToolCalls() {
if (CollectionUtils.isEmpty(generations)) {
return false;
}
return generations.stream().anyMatch(generation -> generation.getOutput().hasToolCalls());
}

@Override
public String toString() {
return "ChatResponse [metadata=" + this.chatResponseMetadata + ", generations=" + this.generations + "]";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.model.function.FunctionCallback.SchemaType;
import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
import org.springframework.ai.tool.resolution.TypeResolverHelper;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
Expand Down Expand Up @@ -55,7 +57,9 @@
* @author Christian Tzolov
* @author Christopher Smith
* @author Sebastien Deleuze
* @deprecated Use {@link SpringBeanToolCallbackResolver} instead.
*/
@Deprecated
public class DefaultFunctionCallbackResolver implements ApplicationContextAware, FunctionCallbackResolver {

private GenericApplicationContext applicationContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@

package org.springframework.ai.model.function;

import org.springframework.ai.tool.resolution.ToolCallbackResolver;
import org.springframework.lang.NonNull;

/**
* Strategy interface for resolving {@link FunctionCallback} instances.
*
* @author Christian Tzolov
* @since 1.0.0
* @deprecated Use {@link ToolCallbackResolver} instead.
*/
@Deprecated
public interface FunctionCallbackResolver {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Set;
import java.util.function.Function;

import org.springframework.ai.model.tool.ToolCallingManager;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.messages.AssistantMessage;
Expand All @@ -40,7 +41,10 @@
* Helper class that reuses the {@link AbstractToolCallSupport} to implement the function
* call handling logic on the client side. Used when the withProxyToolCalls(true) option
* is enabled.
*
* @deprecated Use {@link ToolCallingManager} instead.
*/
@Deprecated
public class FunctionCallingHelper extends AbstractToolCallSupport {

public FunctionCallingHelper() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@

import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand All @@ -39,14 +37,14 @@
*/
public class DefaultToolCallingChatOptions implements ToolCallingChatOptions {

private List<ToolCallback> toolCallbacks = new ArrayList<>();
private List<FunctionCallback> toolCallbacks = new ArrayList<>();

private Set<String> tools = new HashSet<>();

private Map<String, Object> toolContext = new HashMap<>();

@Nullable
private Boolean toolCallReturnDirect;
private Boolean toolExecutionEnabled;

@Nullable
private String model;
Expand All @@ -73,23 +71,17 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions {
private Double topP;

@Override
public List<ToolCallback> getToolCallbacks() {
public List<FunctionCallback> getToolCallbacks() {
return List.copyOf(this.toolCallbacks);
}

@Override
public void setToolCallbacks(List<ToolCallback> toolCallbacks) {
public void setToolCallbacks(List<FunctionCallback> toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
this.toolCallbacks = new ArrayList<>(toolCallbacks);
}

@Override
public void setToolCallbacks(ToolCallback... toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
setToolCallbacks(List.of(toolCallbacks));
}

@Override
public Set<String> getTools() {
return Set.copyOf(this.tools);
Expand All @@ -103,12 +95,6 @@ public void setTools(Set<String> tools) {
this.tools = new HashSet<>(tools);
}

@Override
public void setTools(String... tools) {
Assert.notNull(tools, "tools cannot be null");
setTools(Set.of(tools));
}

@Override
public Map<String, Object> getToolContext() {
return Map.copyOf(this.toolContext);
Expand All @@ -123,23 +109,23 @@ public void setToolContext(Map<String, Object> toolContext) {

@Override
@Nullable
public Boolean getToolCallReturnDirect() {
return this.toolCallReturnDirect;
public Boolean isToolExecutionEnabled() {
return this.toolExecutionEnabled;
}

@Override
public void setToolCallReturnDirect(@Nullable Boolean toolCallReturnDirect) {
this.toolCallReturnDirect = toolCallReturnDirect;
public void setToolExecutionEnabled(@Nullable Boolean toolExecutionEnabled) {
this.toolExecutionEnabled = toolExecutionEnabled;
}

@Override
public List<FunctionCallback> getFunctionCallbacks() {
return getToolCallbacks().stream().map(FunctionCallback.class::cast).toList();
return getToolCallbacks();
}

@Override
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
throw new UnsupportedOperationException("Not supported. Call setToolCallbacks instead.");
setToolCallbacks(functionCallbacks);
}

@Override
Expand All @@ -155,12 +141,12 @@ public void setFunctions(Set<String> functions) {
@Override
@Nullable
public Boolean getProxyToolCalls() {
return getToolCallReturnDirect();
return isToolExecutionEnabled() != null ? !isToolExecutionEnabled() : null;
}

@Override
public void setProxyToolCalls(@Nullable Boolean proxyToolCalls) {
setToolCallReturnDirect(proxyToolCalls != null && proxyToolCalls);
setToolExecutionEnabled(proxyToolCalls == null || !proxyToolCalls);
}

@Override
Expand Down Expand Up @@ -250,7 +236,7 @@ public <T extends ChatOptions> T copy() {
options.setToolCallbacks(getToolCallbacks());
options.setTools(getTools());
options.setToolContext(getToolContext());
options.setToolCallReturnDirect(getToolCallReturnDirect());
options.setToolExecutionEnabled(isToolExecutionEnabled());
options.setModel(getModel());
options.setFrequencyPenalty(getFrequencyPenalty());
options.setMaxTokens(getMaxTokens());
Expand All @@ -262,55 +248,6 @@ public <T extends ChatOptions> T copy() {
return (T) options;
}

/**
* Merge the given {@link ChatOptions} into this instance.
*/
public ToolCallingChatOptions merge(ChatOptions options) {
ToolCallingChatOptions.Builder builder = ToolCallingChatOptions.builder();
builder.model(StringUtils.hasText(options.getModel()) ? options.getModel() : this.getModel());
builder.frequencyPenalty(
options.getFrequencyPenalty() != null ? options.getFrequencyPenalty() : this.getFrequencyPenalty());
builder.maxTokens(options.getMaxTokens() != null ? options.getMaxTokens() : this.getMaxTokens());
builder.presencePenalty(
options.getPresencePenalty() != null ? options.getPresencePenalty() : this.getPresencePenalty());
builder.stopSequences(options.getStopSequences() != null ? new ArrayList<>(options.getStopSequences())
: this.getStopSequences());
builder.temperature(options.getTemperature() != null ? options.getTemperature() : this.getTemperature());
builder.topK(options.getTopK() != null ? options.getTopK() : this.getTopK());
builder.topP(options.getTopP() != null ? options.getTopP() : this.getTopP());

if (options instanceof ToolCallingChatOptions toolOptions) {
List<ToolCallback> toolCallbacks = new ArrayList<>(this.toolCallbacks);
if (!CollectionUtils.isEmpty(toolOptions.getToolCallbacks())) {
toolCallbacks.addAll(toolOptions.getToolCallbacks());
}
builder.toolCallbacks(toolCallbacks);

Set<String> tools = new HashSet<>(this.tools);
if (!CollectionUtils.isEmpty(toolOptions.getTools())) {
tools.addAll(toolOptions.getTools());
}
builder.tools(tools);

Map<String, Object> toolContext = new HashMap<>(this.toolContext);
if (!CollectionUtils.isEmpty(toolOptions.getToolContext())) {
toolContext.putAll(toolOptions.getToolContext());
}
builder.toolContext(toolContext);

builder.toolCallReturnDirect(toolOptions.getToolCallReturnDirect() != null
? toolOptions.getToolCallReturnDirect() : this.getToolCallReturnDirect());
}
else {
builder.toolCallbacks(this.toolCallbacks);
builder.tools(this.tools);
builder.toolContext(this.toolContext);
builder.toolCallReturnDirect(this.toolCallReturnDirect);
}

return builder.build();
}

public static Builder builder() {
return new Builder();
}
Expand All @@ -323,14 +260,15 @@ public static class Builder implements ToolCallingChatOptions.Builder {
private final DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions();

@Override
public ToolCallingChatOptions.Builder toolCallbacks(List<ToolCallback> toolCallbacks) {
public ToolCallingChatOptions.Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
this.options.setToolCallbacks(toolCallbacks);
return this;
}

@Override
public ToolCallingChatOptions.Builder toolCallbacks(ToolCallback... toolCallbacks) {
this.options.setToolCallbacks(toolCallbacks);
public ToolCallingChatOptions.Builder toolCallbacks(FunctionCallback... toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
this.options.setToolCallbacks(Arrays.asList(toolCallbacks));
return this;
}

Expand All @@ -342,7 +280,8 @@ public ToolCallingChatOptions.Builder tools(Set<String> toolNames) {

@Override
public ToolCallingChatOptions.Builder tools(String... toolNames) {
this.options.setTools(toolNames);
Assert.notNull(toolNames, "toolNames cannot be null");
this.options.setTools(Set.of(toolNames));
return this;
}

Expand All @@ -363,16 +302,15 @@ public ToolCallingChatOptions.Builder toolContext(String key, Object value) {
}

@Override
public ToolCallingChatOptions.Builder toolCallReturnDirect(@Nullable Boolean toolCallReturnDirect) {
this.options.setToolCallReturnDirect(toolCallReturnDirect);
public ToolCallingChatOptions.Builder toolExecutionEnabled(@Nullable Boolean toolExecutionEnabled) {
this.options.setToolExecutionEnabled(toolExecutionEnabled);
return this;
}

@Override
@Deprecated // Use toolCallbacks() instead
public ToolCallingChatOptions.Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
Assert.notNull(functionCallbacks, "functionCallbacks cannot be null");
return toolCallbacks(functionCallbacks.stream().map(ToolCallback.class::cast).toList());
return toolCallbacks(functionCallbacks);
}

@Override
Expand All @@ -395,9 +333,9 @@ public ToolCallingChatOptions.Builder function(String function) {
}

@Override
@Deprecated // Use toolCallReturnDirect() instead
@Deprecated // Use toolExecutionEnabled() instead
public ToolCallingChatOptions.Builder proxyToolCalls(@Nullable Boolean proxyToolCalls) {
return toolCallReturnDirect(proxyToolCalls != null && proxyToolCalls);
return toolExecutionEnabled(proxyToolCalls == null || !proxyToolCalls);
}

@Override
Expand Down
Loading

0 comments on commit f678969

Please sign in to comment.