diff --git a/camel/agents/chat_agent.py b/camel/agents/chat_agent.py index 8621a94483..8a76d4259f 100644 --- a/camel/agents/chat_agent.py +++ b/camel/agents/chat_agent.py @@ -63,7 +63,7 @@ from openai import Stream from camel.terminators import ResponseTerminator - from camel.toolkits import OpenAIFunction + from camel.toolkits import FunctionTool logger = logging.getLogger(__name__) @@ -115,7 +115,8 @@ class ChatAgent(BaseAgent): r"""Class for managing conversations of CAMEL Chat Agents. Args: - system_message (BaseMessage): The system message for the chat agent. + system_message (BaseMessage, optional): The system message for the + chat agent. model (BaseModelBackend, optional): The model backend to use for generating responses. (default: :obj:`OpenAIModel` with `GPT_4O_MINI`) @@ -131,10 +132,10 @@ class ChatAgent(BaseAgent): (default: :obj:`None`) output_language (str, optional): The language to be output by the agent. (default: :obj:`None`) - tools (List[OpenAIFunction], optional): List of available - :obj:`OpenAIFunction`. (default: :obj:`None`) - external_tools (List[OpenAIFunction], optional): List of external tools - (:obj:`OpenAIFunction`) bind to one chat agent. When these tools + tools (List[FunctionTool], optional): List of available + :obj:`FunctionTool`. (default: :obj:`None`) + external_tools (List[FunctionTool], optional): List of external tools + (:obj:`FunctionTool`) bind to one chat agent. When these tools are called, the agent will directly return the request instead of processing it. (default: :obj:`None`) response_terminators (List[ResponseTerminator], optional): List of @@ -144,20 +145,24 @@ class ChatAgent(BaseAgent): def __init__( self, - system_message: BaseMessage, + system_message: Optional[BaseMessage] = None, model: Optional[BaseModelBackend] = None, memory: Optional[AgentMemory] = None, message_window_size: Optional[int] = None, token_limit: Optional[int] = None, output_language: Optional[str] = None, - tools: Optional[List[OpenAIFunction]] = None, - external_tools: Optional[List[OpenAIFunction]] = None, + tools: Optional[List[FunctionTool]] = None, + external_tools: Optional[List[FunctionTool]] = None, response_terminators: Optional[List[ResponseTerminator]] = None, ) -> None: - self.orig_sys_message: BaseMessage = system_message - self.system_message = system_message - self.role_name: str = system_message.role_name - self.role_type: RoleType = system_message.role_type + self.orig_sys_message: Optional[BaseMessage] = system_message + self._system_message: Optional[BaseMessage] = system_message + self.role_name: str = ( + getattr(system_message, 'role_name', None) or "assistant" + ) + self.role_type: RoleType = ( + getattr(system_message, 'role_type', None) or RoleType.ASSISTANT + ) self.model_backend: BaseModelBackend = ( model if model is not None @@ -272,11 +277,12 @@ def reset(self): terminator.reset() @property - def system_message(self) -> BaseMessage: + def system_message(self) -> Optional[BaseMessage]: r"""The getter method for the property :obj:`system_message`. Returns: - BaseMessage: The system message of this agent. + Optional[BaseMessage]: The system message of this agent if set, + else :obj:`None`. """ return self._system_message @@ -327,12 +333,22 @@ def set_output_language(self, output_language: str) -> BaseMessage: BaseMessage: The updated system message object. """ self.output_language = output_language - content = self.orig_sys_message.content + ( + language_prompt = ( "\nRegardless of the input language, " f"you must output text in {output_language}." ) - self.system_message = self.system_message.create_new_instance(content) - return self.system_message + if self.orig_sys_message is not None: + content = self.orig_sys_message.content + language_prompt + self._system_message = self.orig_sys_message.create_new_instance( + content + ) + return self._system_message + else: + self._system_message = BaseMessage.make_assistant_message( + role_name="Assistant", + content=language_prompt, + ) + return self._system_message def get_info( self, @@ -377,12 +393,15 @@ def init_messages(self) -> None: r"""Initializes the stored messages list with the initial system message. """ - system_record = MemoryRecord( - message=self.system_message, - role_at_backend=OpenAIBackendRole.SYSTEM, - ) - self.memory.clear() - self.memory.write_record(system_record) + if self.orig_sys_message is not None: + system_record = MemoryRecord( + message=self.orig_sys_message, + role_at_backend=OpenAIBackendRole.SYSTEM, + ) + self.memory.clear() + self.memory.write_record(system_record) + else: + self.memory.clear() def record_message(self, message: BaseMessage) -> None: r"""Records the externally provided message into the agent memory as if @@ -795,12 +814,12 @@ def _structure_output_with_function( r"""Internal function of structuring the output of the agent based on the given output schema. """ - from camel.toolkits import OpenAIFunction + from camel.toolkits import FunctionTool schema_json = get_pydantic_object_schema(output_schema) func_str = json_to_function_code(schema_json) func_callable = func_string_to_callable(func_str) - func = OpenAIFunction(func_callable) + func = FunctionTool(func_callable) original_func_dict = self.func_dict original_model_dict = self.model_backend.model_config_dict diff --git a/camel/configs/base_config.py b/camel/configs/base_config.py index 3d6c95bdc9..8ce9ff24a5 100644 --- a/camel/configs/base_config.py +++ b/camel/configs/base_config.py @@ -39,13 +39,13 @@ class BaseConfig(ABC, BaseModel): @classmethod def fields_type_checking(cls, tools): if tools is not None: - from camel.toolkits import OpenAIFunction + from camel.toolkits import FunctionTool for tool in tools: - if not isinstance(tool, OpenAIFunction): + if not isinstance(tool, FunctionTool): raise ValueError( f"The tool {tool} should " - "be an instance of `OpenAIFunction`." + "be an instance of `FunctionTool`." ) return tools @@ -54,14 +54,14 @@ def as_dict(self) -> dict[str, Any]: tools_schema = None if self.tools: - from camel.toolkits import OpenAIFunction + from camel.toolkits import FunctionTool tools_schema = [] for tool in self.tools: - if not isinstance(tool, OpenAIFunction): + if not isinstance(tool, FunctionTool): raise ValueError( f"The tool {tool} should " - "be an instance of `OpenAIFunction`." + "be an instance of `FunctionTool`." ) tools_schema.append(tool.get_openai_tool_schema()) config_dict["tools"] = tools_schema diff --git a/camel/configs/groq_config.py b/camel/configs/groq_config.py index fe3f2f90ab..43fe2e076f 100644 --- a/camel/configs/groq_config.py +++ b/camel/configs/groq_config.py @@ -73,7 +73,7 @@ class GroqConfig(BaseConfig): user (str, optional): A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. (default: :obj:`""`) - tools (list[OpenAIFunction], optional): A list of tools the model may + tools (list[FunctionTool], optional): A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported. diff --git a/camel/configs/openai_config.py b/camel/configs/openai_config.py index 1adb8cc3fe..971232ad8f 100644 --- a/camel/configs/openai_config.py +++ b/camel/configs/openai_config.py @@ -81,7 +81,7 @@ class ChatGPTConfig(BaseConfig): user (str, optional): A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. (default: :obj:`""`) - tools (list[OpenAIFunction], optional): A list of tools the model may + tools (list[FunctionTool], optional): A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported. diff --git a/camel/configs/samba_config.py b/camel/configs/samba_config.py index 444172270f..409f939cfd 100644 --- a/camel/configs/samba_config.py +++ b/camel/configs/samba_config.py @@ -172,7 +172,7 @@ class SambaCloudAPIConfig(BaseConfig): user (str, optional): A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. (default: :obj:`""`) - tools (list[OpenAIFunction], optional): A list of tools the model may + tools (list[FunctionTool], optional): A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported. diff --git a/camel/configs/zhipuai_config.py b/camel/configs/zhipuai_config.py index 89f2061dfb..21bbec1e7f 100644 --- a/camel/configs/zhipuai_config.py +++ b/camel/configs/zhipuai_config.py @@ -45,7 +45,7 @@ class ZhipuAIConfig(BaseConfig): in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length. (default: :obj:`None`) - tools (list[OpenAIFunction], optional): A list of tools the model may + tools (list[FunctionTool], optional): A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported. diff --git a/camel/messages/__init__.py b/camel/messages/__init__.py index 6e93b5c437..870626024c 100644 --- a/camel/messages/__init__.py +++ b/camel/messages/__init__.py @@ -32,6 +32,7 @@ 'OpenAISystemMessage', 'OpenAIAssistantMessage', 'OpenAIUserMessage', + 'OpenAIFunctionMessage', 'OpenAIMessage', 'BaseMessage', 'FunctionCallingMessage', diff --git a/camel/societies/babyagi_playing.py b/camel/societies/babyagi_playing.py index 45c5c4f72a..cf9feacfbd 100644 --- a/camel/societies/babyagi_playing.py +++ b/camel/societies/babyagi_playing.py @@ -106,7 +106,7 @@ def __init__( ) self.assistant_agent: ChatAgent - self.assistant_sys_msg: BaseMessage + self.assistant_sys_msg: Optional[BaseMessage] self.task_creation_agent: TaskCreationAgent self.task_prioritization_agent: TaskPrioritizationAgent self.init_agents( @@ -202,7 +202,8 @@ def init_agents( self.task_creation_agent = TaskCreationAgent( objective=self.specified_task_prompt, - role_name=self.assistant_sys_msg.role_name, + role_name=getattr(self.assistant_sys_msg, 'role_name', None) + or "assistant", output_language=output_language, message_window_size=message_window_size, **(task_creation_agent_kwargs or {}), @@ -238,7 +239,9 @@ def step(self) -> ChatAgentResponse: task_name = self.subtasks.popleft() assistant_msg_msg = BaseMessage.make_user_message( - role_name=self.assistant_sys_msg.role_name, content=f"{task_name}" + role_name=getattr(self.assistant_sys_msg, 'role_name', None) + or "assistant", + content=f"{task_name}", ) assistant_response = self.assistant_agent.step(assistant_msg_msg) diff --git a/camel/societies/role_playing.py b/camel/societies/role_playing.py index a5c16edbd2..86cd4c4244 100644 --- a/camel/societies/role_playing.py +++ b/camel/societies/role_playing.py @@ -149,8 +149,8 @@ def __init__( self.assistant_agent: ChatAgent self.user_agent: ChatAgent - self.assistant_sys_msg: BaseMessage - self.user_sys_msg: BaseMessage + self.assistant_sys_msg: Optional[BaseMessage] + self.user_sys_msg: Optional[BaseMessage] self._init_agents( init_assistant_sys_msg, init_user_sys_msg, @@ -454,9 +454,11 @@ def init_chat(self, init_msg_content: Optional[str] = None) -> BaseMessage: ) if init_msg_content is None: init_msg_content = default_init_msg_content + # Initialize a message sent by the assistant init_msg = BaseMessage.make_assistant_message( - role_name=self.assistant_sys_msg.role_name, + role_name=getattr(self.assistant_sys_msg, 'role_name', None) + or "assistant", content=init_msg_content, ) diff --git a/camel/tasks/machine.py b/camel/tasks/machine.py new file mode 100644 index 0000000000..30c116cf89 --- /dev/null +++ b/camel/tasks/machine.py @@ -0,0 +1,180 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +from typing import Dict, List + + +class Machine: + r"""A state machine allowing transitions between different states based on + triggers. + + Args: + states (List[str]): A list of states in the machine. + transitions (List[dict]): A list of transitions, where each transition + is represented as a dictionary containing 'trigger', 'source', and + 'dest'. + initial (str): The initial state of the machine. + + Attributes: + states (List[str]): List of states within the state machine. + transitions (List[dict]): List of state transitions. + initial_state (str): The initial state. + current_state (str): The current active state of the machine. + transition_map (Dict[str, Dict[str, str]]): Mapping of triggers to + source-destination state pairs. + """ + + def __init__( + self, states: List[str], transitions: List[dict], initial: str + ): + self.states = states + self.transitions = transitions + self.initial_state = initial + self.current_state = initial + + if self.initial_state not in self.states: + raise ValueError( + f"Initial state '{self.initial_state}' must be in the states." + ) + + self.transition_map = self._build_transition_map() + + def _build_transition_map(self) -> Dict[str, Dict[str, str]]: + r"""Constructs a mapping from triggers to state transitions. + + Returns: + Dict[str, Dict[str, str]]: A nested dictionary where each key is a + trigger, mapping to another dictionary of source-destination + state pairs. + """ + transition_map: Dict[str, Dict[str, str]] = {} + for transition in self.transitions: + trigger = transition['trigger'] + source = transition['source'] + dest = transition['dest'] + + if trigger not in transition_map: + transition_map[trigger] = {} + + transition_map[trigger][source] = dest + + return transition_map + + def add_state(self, state: str): + r"""Adds a new state to the machine if it doesn't already exist. + + Args: + state (str): The state to be added. + """ + if state not in self.states: + self.states.append(state) + + def add_transition(self, trigger: str, source: str, dest: str): + r"""Adds a new transition between states based on a trigger. + + Args: + trigger (str): The trigger for the transition. + source (str): The source state. + dest (str): The destination state. + + Raises: + ValueError: If either the source or destination state is invalid. + """ + if source not in self.states or dest not in self.states: + raise ValueError( + f"Both source '{source}' and destination '{dest}' must be " + + "valid states." + ) + + if trigger not in self.transition_map: + self.transition_map[trigger] = {} + + self.transition_map[trigger][source] = dest + + def trigger(self, trigger: str): + r"""Executes a state transition based on a given trigger. + + Args: + trigger (str): The trigger to initiate the state transition. + + Raises: + ValueError: If there is no valid transition for the given trigger + from the current state. + """ + if ( + trigger in self.transition_map + and self.current_state in self.transition_map[trigger] + ): + new_state = self.transition_map[trigger][self.current_state] + self.current_state = new_state + else: + print( + f"No valid transition for trigger '{trigger}' from state '" + + f"{self.current_state}'." + ) + + def reset(self): + r"""Resets the machine to its initial state.""" + self.current_state = self.initial_state + + def get_current_state(self) -> str: + r"""Gets the current active state of the machine. + + Returns: + str: The current state. + """ + return self.current_state + + def get_available_triggers(self) -> List[str]: + r"""Lists triggers available from the current state. + + Returns: + List[str]: List of available triggers. + """ + return [ + trigger + for trigger, transitions in self.transition_map.items() + if self.current_state in transitions + ] + + def remove_state(self, state: str): + r"""Removes a state from the machine and any transitions involving it. + + Args: + state (str): The state to be removed. + """ + if state in self.states: + self.states.remove(state) + # Remove transitions associated with the state + self.transition_map = { + trigger: { + src: dst + for src, dst in transitions.items() + if src != state and dst != state + } + for trigger, transitions in self.transition_map.items() + } + + def remove_transition(self, trigger: str, source: str): + r"""Removes a specific transition from a given source state based on a + trigger. + + Args: + trigger (str): The trigger associated with the transition. + source (str): The source state. + """ + if ( + trigger in self.transition_map + and source in self.transition_map[trigger] + ): + del self.transition_map[trigger][source] diff --git a/camel/tasks/task.py b/camel/tasks/task.py index 2521f5b0ea..a3018ea5a8 100644 --- a/camel/tasks/task.py +++ b/camel/tasks/task.py @@ -14,13 +14,15 @@ import re from enum import Enum -from typing import Callable, Dict, List, Literal, Optional, Union +from typing import Callable, Dict, List, Literal, Optional, TypedDict, Union -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from camel.agents import ChatAgent from camel.messages import BaseMessage from camel.prompts import TextPrompt +from camel.tasks.machine import Machine +from camel.toolkits.function_tool import FunctionTool from .task_prompt import ( TASK_COMPOSE_PROMPT, @@ -75,8 +77,9 @@ class Task(BaseModel): state: The state which should be OPEN, RUNNING, DONE or DELETED. type: task type parent: The parent task, None for root task. - subtasks: The childrent sub-tasks for the task. + subtasks: The children sub-tasks for the task. result: The answer for the task. + onTaskComplete: The callback functions when the task is completed. """ content: str @@ -97,6 +100,8 @@ class Task(BaseModel): additional_info: Optional[str] = None + onTaskComplete: List[Callable[[], None]] = [] + @classmethod def from_message(cls, message: BaseMessage) -> "Task": r"""Create a task from a message. @@ -119,6 +124,7 @@ def reset(self): r"""Reset Task to initial state.""" self.state = TaskState.OPEN self.result = "" + self.onTaskComplete = [] def update_result(self, result: str): r"""Set task result and mark the task as DONE. @@ -128,6 +134,23 @@ def update_result(self, result: str): """ self.result = result self.set_state(TaskState.DONE) + for callback in self.onTaskComplete or []: + callback() + + def add_on_task_complete( + self, callbacks: Union[Callable[[], None], List[Callable[[], None]]] + ): + r"""Adds one or more callbacks to the onTaskComplete list. + + Args: + callbacks (Union[Callable[[], None], List[Callable[[], None]]]): + A single callback function or a list of callback functions to + be added. + """ + if isinstance(callbacks, list): + self.onTaskComplete.extend(callbacks) + else: + self.onTaskComplete.append(callbacks) def set_id(self, id: str): self.id = id @@ -289,8 +312,9 @@ class TaskManager: def __init__(self, task: Task): self.root_task: Task = task self.current_task_id: str = task.id - self.tasks: List[Task] = [task] - self.task_map: Dict[str, Task] = {task.id: task} + self.tasks: List[Task] = [] + self.task_map: Dict[str, Task] = {} + self.add_tasks(task) def gen_task_id(self) -> str: r"""Generate a new task id.""" @@ -370,13 +394,49 @@ def add_tasks(self, tasks: Union[Task, List[Task]]) -> None: r"""self.tasks and self.task_map will be updated by the input tasks.""" if not tasks: return - if not isinstance(tasks, List): - tasks = [tasks] + tasks = tasks if isinstance(tasks, list) else [tasks] + for task in tasks: - assert not self.exist(task.id), f"`{task.id}` already existed." + assert not self.exist(task.id), f"`{task.id}` already exists." + # Add callback to update the current task ID + task.add_on_task_complete( + self.create_update_current_task_id_callback(task) + ) self.tasks = self.topological_sort(self.tasks + tasks) self.task_map = {task.id: task for task in self.tasks} + def create_update_current_task_id_callback( + self, task: Task + ) -> Callable[[], None]: + r""" + Creates a callback function to update `current_task_id` after the + specified task is completed. + + The returned callback function, when called, sets `current_task_id` to + the ID of the next task in the `tasks` list. If the specified task is + the last one in the list, `current_task_id` is set to an empty string. + + Args: + task (Task): The task after which `current_task_id` should be + updated. + + Returns: + Callable[[], None]: + A callback function that, when invoked, updates + `current_task_id` to the next task's ID or clears it if the + specified task is the last in the list. + """ + + def update_current_task_id(): + current_index = self.tasks.index(task) + + if current_index + 1 < len(self.tasks): + self.current_task_id = self.tasks[current_index + 1].id + else: + self.current_task_id = "" + + return update_current_task_id + def evolve( self, task: Task, @@ -413,3 +473,173 @@ def evolve( if tasks: return tasks[0] return None + + +class FunctionToolState(BaseModel): + """Represents a specific state of a function tool within a state machine. + + Attributes: + name (str): The unique name of the state, serving as an identifier. + tools_space (Optional[List[FunctionTool]]): A list of `FunctionTool` + objects associated with this state. Defaults to an empty list. + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) + + name: str + tools_space: Optional[List[FunctionTool]] = [] + + def __eq__(self, other): + if isinstance(other, FunctionToolState): + return self.name == other.name + return False + + def __hash__(self): + return hash(self.name) + + def __str__(self): + return self.name + + def __repr__(self): + return self.name + + +class FunctionToolTransition(TypedDict): + """Defines a transition within a function tool state machine. + + Attributes: + trigger (Task): The task that initiates this transition. + source (str): The name of the starting state. + dest (str): The name of the target state after the transition. + """ + + trigger: Task + source: FunctionToolState + dest: FunctionToolState + + +class TaskManagerWithState(TaskManager): + r""" + A TaskManager with an integrated state machine supporting conditional + function calls based on task states and transitions. + + Args: + task (Task): The primary task to manage. + states (List[FunctionToolState]): A list of all possible states. + initial_state (str): The initial state of the machine. + transitions (Optional[List[FunctionToolTransition]]): A list of state + transitions, each containing a trigger, source state, and + destination state. + + Attributes: + state_space (List[FunctionToolState]): Collection of all states. + machine (Machine): Manages state transitions and tracks the current + state. + """ + + def __init__( + self, + task, + initial_state: FunctionToolState, + states: List[FunctionToolState], + transitions: Optional[List[FunctionToolTransition]], + ): + super().__init__(task) + self.state_space = states + self.machine = Machine( + states=[state.name for state in states], + transitions=[], + initial=initial_state.name, + ) + + for transition in transitions or []: + self.add_transition( + transition['trigger'], + transition['source'], + transition['dest'], + ) + + @property + def current_state(self) -> Optional[FunctionToolState]: + r""" + Retrieves the current state object based on the machine's active state + name. + + Returns: + Optional[FunctionToolState]: + The state object with a `name` matching `self.machine. + current_state`, or `None` if no match is found. + """ + for state in self.state_space: + if state.name == self.machine.current_state: + return state + return None + + def add_state(self, state: FunctionToolState): + r""" + Adds a new state to both `state_space` and the machine. + + Args: + state (FunctionToolState): The state object to be added. + """ + if state not in self.state_space: + self.state_space.append(state) + self.machine.add_state(state.name) + + def add_transition( + self, trigger: Task, source: FunctionToolState, dest: FunctionToolState + ): + r""" + Adds a new transition to the state machine. + + Args: + trigger (Task): The task that initiates this transition. + source (FunctionToolState): The source state. + dest (FunctionToolState): The destination state after the + transition. + """ + trigger.add_on_task_complete(lambda: self.machine.trigger(trigger.id)) + if not self.exist(trigger.id): + self.add_tasks(trigger) + self.machine.add_transition(trigger.id, source.name, dest.name) + + def get_current_tools(self) -> Optional[List[FunctionTool]]: + r""" + Retrieves the tools available in the current state. + + Returns: + Optional[List[FunctionTool]]: The tools associated with the + current state. + """ + return self.current_state.tools_space + + def remove_state(self, state_name: str): + r""" + Removes a state from `state_space` and the machine. + + Args: + state_name (str): The name of the state to remove. + """ + self.state_space = [ + state for state in self.state_space if state.name != state_name + ] + self.machine.remove_state(state_name) + + def remove_transition(self, trigger: Task, source: FunctionToolState): + r""" + Removes a transition from the machine. + + Args: + trigger (Task): The task associated with the transition to remove. + source (FunctionToolState): The source state. + """ + self.machine.remove_transition(trigger.id, source.name) + + def reset(self): + r""" + Resets the machine to its initial state. + """ + self.machine.reset() diff --git a/camel/toolkits/__init__.py b/camel/toolkits/__init__.py index 87f2326ff3..db9b92387b 100644 --- a/camel/toolkits/__init__.py +++ b/camel/toolkits/__init__.py @@ -12,7 +12,8 @@ # limitations under the License. # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== # ruff: noqa: I001 -from .openai_function import ( +from .function_tool import ( + FunctionTool, OpenAIFunction, get_openai_function_schema, get_openai_tool_schema, @@ -33,8 +34,10 @@ from .code_execution import CodeExecutionToolkit from .github_toolkit import GithubToolkit +from .toolkits_manager import ToolkitManager __all__ = [ + 'FunctionTool', 'OpenAIFunction', 'get_openai_function_schema', 'get_openai_tool_schema', @@ -56,4 +59,5 @@ 'SEARCH_FUNCS', 'WEATHER_FUNCS', 'DALLE_FUNCS', + 'ToolkitManager', ] diff --git a/camel/toolkits/base.py b/camel/toolkits/base.py index 6e6ee2f6ba..3bc8498773 100644 --- a/camel/toolkits/base.py +++ b/camel/toolkits/base.py @@ -12,13 +12,30 @@ # limitations under the License. # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +import inspect from typing import List from camel.utils import AgentOpsMeta -from .openai_function import OpenAIFunction +from .function_tool import FunctionTool class BaseToolkit(metaclass=AgentOpsMeta): - def get_tools(self) -> List[OpenAIFunction]: - raise NotImplementedError("Subclasses must implement this method.") + def get_tools(self) -> List[FunctionTool]: + """Returns a list of FunctionTool objects representing the + functions in the toolkit. + + Returns: + List[FunctionTool]: A list of FunctionTool objects + representing the functions in the toolkit. + """ + tools = [] + for _, method in inspect.getmembers(self, predicate=inspect.ismethod): + if getattr(method, '_is_exported', False): + tools.append( + FunctionTool( + func=method, name_prefix=self.__class__.__name__ + ) + ) + + return tools diff --git a/camel/toolkits/code_execution.py b/camel/toolkits/code_execution.py index 0fcc3243b2..f841d7362c 100644 --- a/camel/toolkits/code_execution.py +++ b/camel/toolkits/code_execution.py @@ -11,10 +11,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== -from typing import List, Literal +from typing import Literal from camel.interpreters import InternalPythonInterpreter -from camel.toolkits import OpenAIFunction +from camel.utils.commons import export_to_toolkit from .base import BaseToolkit @@ -42,6 +42,7 @@ def __init__( f"The sandbox type `{sandbox}` is not supported." ) + @export_to_toolkit def execute_code(self, code: str) -> str: r"""Execute a given code snippet. @@ -57,13 +58,3 @@ def execute_code(self, code: str) -> str: if self.verbose: print(content) return content - - def get_tools(self) -> List[OpenAIFunction]: - r"""Returns a list of OpenAIFunction objects representing the - functions in the toolkit. - - Returns: - List[OpenAIFunction]: A list of OpenAIFunction objects - representing the functions in the toolkit. - """ - return [OpenAIFunction(self.execute_code)] diff --git a/camel/toolkits/dalle_toolkit.py b/camel/toolkits/dalle_toolkit.py index 4898f0d691..cfe2596004 100644 --- a/camel/toolkits/dalle_toolkit.py +++ b/camel/toolkits/dalle_toolkit.py @@ -20,127 +20,130 @@ from openai import OpenAI from PIL import Image -from camel.toolkits import OpenAIFunction +from camel.toolkits import FunctionTool from camel.toolkits.base import BaseToolkit +from camel.utils.commons import export_to_toolkit -class DalleToolkit(BaseToolkit): - r"""A class representing a toolkit for image generation using OpenAI's. +def _base64_to_image(base64_string: str) -> Optional[Image.Image]: + r"""Converts a base64 encoded string into a PIL Image object. - This class provides methods handle image generation using OpenAI's DALL-E. + Args: + base64_string (str): The base64 encoded string of the image. + + Returns: + Optional[Image.Image]: The PIL Image object or None if conversion + fails. + """ + try: + # Decode the base64 string to get the image data + image_data = base64.b64decode(base64_string) + # Create a memory buffer for the image data + image_buffer = BytesIO(image_data) + # Open the image using the PIL library + image = Image.open(image_buffer) + return image + except Exception as e: + print(f"An error occurred while converting base64 to image: {e}") + return None + + +def _image_path_to_base64(image_path: str) -> str: + r"""Converts the file path of an image to a Base64 encoded string. + + Args: + image_path (str): The path to the image file. + + Returns: + str: A Base64 encoded string representing the content of the image + file. + """ + try: + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + except Exception as e: + print(f"An error occurred while converting image path to base64: {e}") + return "" + + +def _image_to_base64(image: Image.Image) -> str: + r"""Converts an image into a base64-encoded string. This function takes + an image object as input, encodes the image into a PNG format base64 + string, and returns it. If the encoding process encounters an error, + it prints the error message and returns `None`. + + Args: + image (Image.Image): The image object to be encoded, supports any + image format that can be saved in PNG format. + + Returns: + str: A base64-encoded string of the image. + """ + try: + with BytesIO() as buffered_image: + image.save(buffered_image, format="PNG") + buffered_image.seek(0) + image_bytes = buffered_image.read() + base64_str = base64.b64encode(image_bytes).decode('utf-8') + return base64_str + except Exception as e: + print(f"An error occurred: {e}") + return "" + + +@export_to_toolkit +def get_dalle_img(prompt: str, image_dir: str = "img") -> str: + r"""Generate an image using OpenAI's DALL-E model. The generated image + is saved to the specified directory. + + Args: + prompt (str): The text prompt based on which the image is + generated. + image_dir (str): The directory to save the generated image. + Defaults to 'img'. + + Returns: + str: The path to the saved image. """ - def base64_to_image(self, base64_string: str) -> Optional[Image.Image]: - r"""Converts a base64 encoded string into a PIL Image object. + dalle_client = OpenAI() + response = dalle_client.images.generate( + model="dall-e-3", + prompt=prompt, + size="1024x1792", + quality="standard", + n=1, # NOTE: now dall-e-3 only supports n=1 + response_format="b64_json", + ) + image_b64 = response.data[0].b64_json + image = _base64_to_image(image_b64) # type: ignore[arg-type] - Args: - base64_string (str): The base64 encoded string of the image. + if image is None: + raise ValueError("Failed to convert base64 string to image.") - Returns: - Optional[Image.Image]: The PIL Image object or None if conversion - fails. - """ - try: - # Decode the base64 string to get the image data - image_data = base64.b64decode(base64_string) - # Create a memory buffer for the image data - image_buffer = BytesIO(image_data) - # Open the image using the PIL library - image = Image.open(image_buffer) - return image - except Exception as e: - print(f"An error occurred while converting base64 to image: {e}") - return None - - def image_path_to_base64(self, image_path: str) -> str: - r"""Converts the file path of an image to a Base64 encoded string. - - Args: - image_path (str): The path to the image file. + os.makedirs(image_dir, exist_ok=True) + image_path = os.path.join(image_dir, f"{uuid.uuid4()}.png") + image.save(image_path) - Returns: - str: A Base64 encoded string representing the content of the image - file. - """ - try: - with open(image_path, "rb") as image_file: - return base64.b64encode(image_file.read()).decode('utf-8') - except Exception as e: - print( - f"An error occurred while converting image path to base64: {e}" - ) - return "" - - def image_to_base64(self, image: Image.Image) -> str: - r"""Converts an image into a base64-encoded string. - - This function takes an image object as input, encodes the image into a - PNG format base64 string, and returns it. - If the encoding process encounters an error, it prints the error - message and returns None. - - Args: - image: The image object to be encoded, supports any image format - that can be saved in PNG format. + return image_path - Returns: - str: A base64-encoded string of the image. - """ - try: - with BytesIO() as buffered_image: - image.save(buffered_image, format="PNG") - buffered_image.seek(0) - image_bytes = buffered_image.read() - base64_str = base64.b64encode(image_bytes).decode('utf-8') - return base64_str - except Exception as e: - print(f"An error occurred: {e}") - return "" - - def get_dalle_img(self, prompt: str, image_dir: str = "img") -> str: - r"""Generate an image using OpenAI's DALL-E model. - The generated image is saved to the specified directory. - - Args: - prompt (str): The text prompt based on which the image is - generated. - image_dir (str): The directory to save the generated image. - Defaults to 'img'. - Returns: - str: The path to the saved image. - """ +class DalleToolkit(BaseToolkit): + r"""A class representing a toolkit for image generation using OpenAI's. + This class provides methods handle image generation using OpenAI's DALL-E. + """ - dalle_client = OpenAI() - response = dalle_client.images.generate( - model="dall-e-3", - prompt=prompt, - size="1024x1792", - quality="standard", - n=1, # NOTE: now dall-e-3 only supports n=1 - response_format="b64_json", - ) - image_b64 = response.data[0].b64_json - image = self.base64_to_image(image_b64) # type: ignore[arg-type] - - if image is None: - raise ValueError("Failed to convert base64 string to image.") - - os.makedirs(image_dir, exist_ok=True) - image_path = os.path.join(image_dir, f"{uuid.uuid4()}.png") - image.save(image_path) - - return image_path - - def get_tools(self) -> List[OpenAIFunction]: - r"""Returns a list of OpenAIFunction objects representing the + def get_tools(self) -> List[FunctionTool]: + r"""Returns a list of FunctionTool objects representing the functions in the toolkit. Returns: - List[OpenAIFunction]: A list of OpenAIFunction objects + List[FunctionTool]: A list of FunctionTool objects representing the functions in the toolkit. """ - return [OpenAIFunction(self.get_dalle_img)] + return DALLE_FUNCS -DALLE_FUNCS: List[OpenAIFunction] = DalleToolkit().get_tools() +DALLE_FUNCS = [ + FunctionTool(func=get_dalle_img, name_prefix=DalleToolkit.__name__) +] diff --git a/camel/toolkits/openai_function.py b/camel/toolkits/function_tool.py similarity index 93% rename from camel/toolkits/openai_function.py rename to camel/toolkits/function_tool.py index e2b35a8843..b8c7d0d0c6 100644 --- a/camel/toolkits/openai_function.py +++ b/camel/toolkits/function_tool.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +import warnings from inspect import Parameter, signature from typing import Any, Callable, Dict, Mapping, Optional, Tuple @@ -142,7 +143,7 @@ def _create_mol(name, field): return openai_tool_schema -class OpenAIFunction: +class FunctionTool: r"""An abstraction of a function that OpenAI chat models can call. See https://platform.openai.com/docs/api-reference/chat/create. @@ -161,11 +162,13 @@ def __init__( self, func: Callable, openai_tool_schema: Optional[Dict[str, Any]] = None, + name_prefix: Optional[str] = None, ) -> None: self.func = func self.openai_tool_schema = openai_tool_schema or get_openai_tool_schema( func ) + self.name_prefix = name_prefix @staticmethod def validate_openai_tool_schema( @@ -387,3 +390,37 @@ def parameters(self, value: Dict[str, Any]) -> None: except SchemaError as e: raise e self.openai_tool_schema["function"]["parameters"]["properties"] = value + + def __str__(self) -> str: + return ( + self.name_prefix + '.' + self.func.__name__ + if self.name_prefix + else self.func.__name__ + ) + + def __repr__(self) -> str: + return ( + self.name_prefix + '.' + self.func.__name__ + if self.name_prefix + else self.func.__name__ + ) + + +warnings.simplefilter('always', DeprecationWarning) + + +# Alias for backwards compatibility +class OpenAIFunction(FunctionTool): + def __init__(self, *args, **kwargs): + PURPLE = '\033[95m' + RESET = '\033[0m' + + def purple_warning(msg): + warnings.warn( + PURPLE + msg + RESET, DeprecationWarning, stacklevel=2 + ) + + purple_warning( + "OpenAIFunction is deprecated, please use FunctionTool instead." + ) + super().__init__(*args, **kwargs) diff --git a/camel/toolkits/github_toolkit.py b/camel/toolkits/github_toolkit.py index 141d7706b6..4fa01f6f7e 100644 --- a/camel/toolkits/github_toolkit.py +++ b/camel/toolkits/github_toolkit.py @@ -18,10 +18,30 @@ from pydantic import BaseModel +from camel.toolkits.base import BaseToolkit from camel.utils import dependencies_required +from camel.utils.commons import export_to_toolkit -from .base import BaseToolkit -from .openai_function import OpenAIFunction + +def get_github_access_token() -> str: + r"""Retrieve the GitHub access token from environment variables. + + Returns: + str: A string containing the GitHub access token. + + Raises: + ValueError: If the API key or secret is not found in the environment + variables. + """ + # Get `GITHUB_ACCESS_TOKEN` here: https://github.com/settings/tokens + github_token = os.environ.get("GITHUB_ACCESS_TOKEN") + + if not github_token: + raise ValueError( + "`GITHUB_ACCESS_TOKEN` not found in environment variables. Get it " + "here: `https://github.com/settings/tokens`." + ) + return github_token class GithubIssue(BaseModel): @@ -69,18 +89,16 @@ class GithubPullRequestDiff(BaseModel): patch: str def __str__(self) -> str: - r"""Returns a string representation of this diff.""" return f"Filename: {self.filename}\nPatch: {self.patch}" class GithubPullRequest(BaseModel): - r"""Represents a pull request on Github. + r"""Represents a pull request on GitHub. Attributes: title (str): The title of the GitHub pull request. body (str): The body/content of the GitHub pull request. - diffs (List[GithubPullRequestDiff]): A list of diffs for the pull - request. + diffs (List[GithubPullRequestDiff]): A list of diffs for the PR. """ title: str @@ -88,7 +106,6 @@ class GithubPullRequest(BaseModel): diffs: List[GithubPullRequestDiff] def __str__(self) -> str: - r"""Returns a string representation of the pull request.""" diff_summaries = '\n'.join(str(diff) for diff in self.diffs) return ( f"Title: {self.title}\n" @@ -99,16 +116,15 @@ def __str__(self) -> str: class GithubToolkit(BaseToolkit): r"""A class representing a toolkit for interacting with GitHub - repositories. - - This class provides methods for retrieving open issues, retrieving - specific issues, and creating pull requests in a GitHub repository. + repositories. This class provides methods for retrieving open issues, + retrieving specific issues, and creating pull requests in a GitHub + repository. Args: repo_name (str): The name of the GitHub repository. - access_token (str, optional): The access token to authenticate with - GitHub. If not provided, it will be obtained using the - `get_github_access_token` method. + access_token (Optional[str], optional): The access token to + authenticate with GitHub. If not provided, it will be obtained + using the `get_github_access_token` method. (default: :obj:`None`) """ @dependencies_required('github') @@ -119,58 +135,26 @@ def __init__( Args: repo_name (str): The name of the GitHub repository. - access_token (str, optional): The access token to authenticate - with GitHub. If not provided, it will be obtained using the - `get_github_access_token` method. + access_token (Optional[str], optional): The access token to + authenticate with GitHub. If not provided, it will be obtained + using the `get_github_access_token` method. + (default: :obj:`None`) """ if access_token is None: - access_token = self.get_github_access_token() + access_token = get_github_access_token() from github import Auth, Github self.github = Github(auth=Auth.Token(access_token)) self.repo = self.github.get_repo(repo_name) - def get_tools(self) -> List[OpenAIFunction]: - r"""Returns a list of OpenAIFunction objects representing the - functions in the toolkit. - - Returns: - List[OpenAIFunction]: A list of OpenAIFunction objects - representing the functions in the toolkit. - """ - return [ - OpenAIFunction(self.retrieve_issue_list), - OpenAIFunction(self.retrieve_issue), - OpenAIFunction(self.create_pull_request), - OpenAIFunction(self.retrieve_pull_requests), - ] - - def get_github_access_token(self) -> str: - r"""Retrieve the GitHub access token from environment variables. - - Returns: - str: A string containing the GitHub access token. - - Raises: - ValueError: If the API key or secret is not found in the - environment variables. - """ - # Get `GITHUB_ACCESS_TOKEN` here: https://github.com/settings/tokens - GITHUB_ACCESS_TOKEN = os.environ.get("GITHUB_ACCESS_TOKEN") - - if not GITHUB_ACCESS_TOKEN: - raise ValueError( - "`GITHUB_ACCESS_TOKEN` not found in environment variables. Get" - " it here: `https://github.com/settings/tokens`." - ) - return GITHUB_ACCESS_TOKEN - + @export_to_toolkit def retrieve_issue_list(self) -> List[GithubIssue]: r"""Retrieve a list of open issues from the repository. Returns: - A list of GithubIssue objects representing the open issues. + List[GithubIssue]: A list of GithubIssue objects representing + the open issues. """ issues = self.repo.get_issues(state='open') return [ @@ -187,17 +171,16 @@ def retrieve_issue_list(self) -> List[GithubIssue]: if not issue.pull_request ] + @export_to_toolkit def retrieve_issue(self, issue_number: int) -> Optional[str]: - r"""Retrieves an issue from a GitHub repository. - - This function retrieves an issue from a specified repository using the - issue number. + r"""Retrieves an issue from a GitHub repository. This function + retrieves an issue from a specified repository using the issue number. Args: issue_number (int): The number of the issue to retrieve. Returns: - str: A formatted report of the retrieved issue. + Optional[str]: A formatted report of the retrieved issue. """ issues = self.retrieve_issue_list() for issue in issues: @@ -205,6 +188,7 @@ def retrieve_issue(self, issue_number: int) -> Optional[str]: return str(issue) return None + @export_to_toolkit def retrieve_pull_requests( self, days: int, state: str, max_prs: int ) -> List[str]: @@ -212,10 +196,8 @@ def retrieve_pull_requests( The summary will be provided for the last specified number of days. Args: - days (int): The number of days to retrieve merged pull requests - for. - state (str): A specific state of PRs to retrieve. Can be open or - closed. + days (int): Number of days to retrieve merged PRs for. + state (str): The state of PRs to retrieve. Can be open or closed. max_prs (int): The maximum number of PRs to retrieve. Returns: @@ -247,6 +229,7 @@ def retrieve_pull_requests( merged_prs.append(str(pr_details)) return merged_prs + @export_to_toolkit def create_pull_request( self, file_path: str, @@ -255,11 +238,10 @@ def create_pull_request( body: str, branch_name: str, ) -> str: - r"""Creates a pull request. - - This function creates a pull request in specified repository, which - updates a file in the specific path with new content. The pull request - description contains information about the issue title and number. + r"""Creates a pull request. This function creates a pull request in + specified repository, which updates a file in the specific path with + new content. The pull request description contains information about + the issue title and number. Args: file_path (str): The path of the file to be updated in the diff --git a/camel/toolkits/google_maps_toolkit.py b/camel/toolkits/google_maps_toolkit.py index 9b15fba5e0..29aaa34a1a 100644 --- a/camel/toolkits/google_maps_toolkit.py +++ b/camel/toolkits/google_maps_toolkit.py @@ -16,8 +16,8 @@ from typing import Any, Callable, List, Optional, Union from camel.toolkits.base import BaseToolkit -from camel.toolkits.openai_function import OpenAIFunction from camel.utils import dependencies_required +from camel.utils.commons import export_to_toolkit def handle_googlemaps_exceptions( @@ -115,6 +115,7 @@ def __init__(self) -> None: self.gmaps = googlemaps.Client(key=api_key) + @export_to_toolkit @handle_googlemaps_exceptions def get_address_description( self, @@ -195,6 +196,7 @@ def get_address_description( return description + @export_to_toolkit @handle_googlemaps_exceptions def get_elevation(self, lat: float, lng: float) -> str: r"""Retrieves elevation data for a given latitude and longitude. @@ -237,6 +239,7 @@ def get_elevation(self, lat: float, lng: float) -> str: return description + @export_to_toolkit @handle_googlemaps_exceptions def get_timezone(self, lat: float, lng: float) -> str: r"""Retrieves timezone information for a given latitude and longitude. @@ -286,17 +289,3 @@ def get_timezone(self, lat: float, lng: float) -> str: ) return description - - def get_tools(self) -> List[OpenAIFunction]: - r"""Returns a list of OpenAIFunction objects representing the - functions in the toolkit. - - Returns: - List[OpenAIFunction]: A list of OpenAIFunction objects - representing the functions in the toolkit. - """ - return [ - OpenAIFunction(self.get_address_description), - OpenAIFunction(self.get_elevation), - OpenAIFunction(self.get_timezone), - ] diff --git a/camel/toolkits/linkedin_toolkit.py b/camel/toolkits/linkedin_toolkit.py index 1993849ce0..24ee407a92 100644 --- a/camel/toolkits/linkedin_toolkit.py +++ b/camel/toolkits/linkedin_toolkit.py @@ -15,13 +15,12 @@ import json import os from http import HTTPStatus -from typing import List import requests -from camel.toolkits import OpenAIFunction from camel.toolkits.base import BaseToolkit from camel.utils import handle_http_error +from camel.utils.commons import export_to_toolkit LINKEDIN_POST_LIMIT = 1300 @@ -36,6 +35,7 @@ class LinkedInToolkit(BaseToolkit): def __init__(self): self._access_token = self._get_access_token() + @export_to_toolkit def create_post(self, text: str) -> dict: r"""Creates a post on LinkedIn for the authenticated user. @@ -86,6 +86,7 @@ def create_post(self, text: str) -> dict: f"Response: {response.text}" ) + @export_to_toolkit def delete_post(self, post_id: str) -> str: r"""Deletes a LinkedIn post with the specified ID for an authorized user. @@ -136,6 +137,7 @@ def delete_post(self, post_id: str) -> str: return f"Post deleted successfully. Post ID: {post_id}." + @export_to_toolkit def get_profile(self, include_id: bool = False) -> dict: r"""Retrieves the authenticated user's LinkedIn profile info. @@ -194,20 +196,6 @@ def get_profile(self, include_id: bool = False) -> dict: return profile_report - def get_tools(self) -> List[OpenAIFunction]: - r"""Returns a list of OpenAIFunction objects representing the - functions in the toolkit. - - Returns: - List[OpenAIFunction]: A list of OpenAIFunction objects - representing the functions in the toolkit. - """ - return [ - OpenAIFunction(self.create_post), - OpenAIFunction(self.delete_post), - OpenAIFunction(self.get_profile), - ] - def _get_access_token(self) -> str: r"""Fetches the access token required for making LinkedIn API requests. diff --git a/camel/toolkits/math_toolkit.py b/camel/toolkits/math_toolkit.py index 398f0070d2..af7f51c142 100644 --- a/camel/toolkits/math_toolkit.py +++ b/camel/toolkits/math_toolkit.py @@ -15,65 +15,70 @@ from typing import List from camel.toolkits.base import BaseToolkit -from camel.toolkits.openai_function import OpenAIFunction +from camel.toolkits.function_tool import FunctionTool +from camel.utils.commons import export_to_toolkit -class MathToolkit(BaseToolkit): - r"""A class representing a toolkit for mathematical operations. +@export_to_toolkit +def add(a: int, b: int) -> int: + r"""Adds two numbers. + + Args: + a (int): The first number to be added. + b (int): The second number to be added. - This class provides methods for basic mathematical operations such as - addition, subtraction, and multiplication. + Returns: + int: The sum of the two numbers. """ + return a + b - def add(self, a: int, b: int) -> int: - r"""Adds two numbers. - Args: - a (int): The first number to be added. - b (int): The second number to be added. +@export_to_toolkit +def sub(a: int, b: int) -> int: + r"""Do subtraction between two numbers. - Returns: - integer: The sum of the two numbers. - """ - return a + b + Args: + a (int): The minuend in subtraction. + b (int): The subtrahend in subtraction. - def sub(self, a: int, b: int) -> int: - r"""Do subtraction between two numbers. + Returns: + int: The result of subtracting :paramref:`b` from :paramref:`a`. + """ + return a - b - Args: - a (int): The minuend in subtraction. - b (int): The subtrahend in subtraction. - Returns: - integer: The result of subtracting :obj:`b` from :obj:`a`. - """ - return a - b +@export_to_toolkit +def mul(a: int, b: int) -> int: + r"""Multiplies two integers. - def mul(self, a: int, b: int) -> int: - r"""Multiplies two integers. + Args: + a (int): The multiplier in the multiplication. + b (int): The multiplicand in the multiplication. - Args: - a (int): The multiplier in the multiplication. - b (int): The multiplicand in the multiplication. + Returns: + int: The product of the two numbers. + """ + return a * b - Returns: - integer: The product of the two numbers. - """ - return a * b - def get_tools(self) -> List[OpenAIFunction]: - r"""Returns a list of OpenAIFunction objects representing the +class MathToolkit(BaseToolkit): + r"""A class representing a toolkit for mathematical operations. This + class provides methods for basic mathematical operations such as addition, + subtraction, and multiplication. + """ + + def get_tools(self) -> List[FunctionTool]: + r"""Returns a list of FunctionTool objects representing the functions in the toolkit. Returns: - List[OpenAIFunction]: A list of OpenAIFunction objects + List[FunctionTool]: A list of FunctionTool objects representing the functions in the toolkit. """ - return [ - OpenAIFunction(self.add), - OpenAIFunction(self.sub), - OpenAIFunction(self.mul), - ] + return MATH_FUNCS -MATH_FUNCS: List[OpenAIFunction] = MathToolkit().get_tools() +MATH_FUNCS = [ + FunctionTool(func=math_func, name_prefix=MathToolkit.__name__) + for math_func in (add, sub, mul) +] diff --git a/camel/toolkits/open_api_toolkit.py b/camel/toolkits/open_api_toolkit.py index 10f90f4bba..bca387c311 100644 --- a/camel/toolkits/open_api_toolkit.py +++ b/camel/toolkits/open_api_toolkit.py @@ -17,7 +17,7 @@ import requests -from camel.toolkits import OpenAIFunction, openapi_security_config +from camel.toolkits import FunctionTool, openapi_security_config from camel.types import OpenAPIName @@ -526,12 +526,12 @@ def generate_apinames_filepaths(self) -> List[Tuple[str, str]]: apinames_filepaths.append((api_name.value, file_path)) return apinames_filepaths - def get_tools(self) -> List[OpenAIFunction]: - r"""Returns a list of OpenAIFunction objects representing the + def get_tools(self) -> List[FunctionTool]: + r"""Returns a list of FunctionTool objects representing the functions in the toolkit. Returns: - List[OpenAIFunction]: A list of OpenAIFunction objects + List[FunctionTool]: A list of FunctionTool objects representing the functions in the toolkit. """ apinames_filepaths = self.generate_apinames_filepaths() @@ -539,6 +539,6 @@ def get_tools(self) -> List[OpenAIFunction]: self.apinames_filepaths_to_funs_schemas(apinames_filepaths) ) return [ - OpenAIFunction(a_func, a_schema) + FunctionTool(a_func, a_schema) for a_func, a_schema in zip(all_funcs_lst, all_schemas_lst) ] diff --git a/camel/toolkits/reddit_toolkit.py b/camel/toolkits/reddit_toolkit.py index 5a9e7d5daf..7795dfdb19 100644 --- a/camel/toolkits/reddit_toolkit.py +++ b/camel/toolkits/reddit_toolkit.py @@ -18,8 +18,8 @@ from requests.exceptions import RequestException -from camel.toolkits import OpenAIFunction from camel.toolkits.base import BaseToolkit +from camel.utils.commons import export_to_toolkit class RedditToolkit(BaseToolkit): @@ -85,6 +85,7 @@ def _retry_request(self, func, *args, **kwargs): else: raise + @export_to_toolkit def collect_top_posts( self, subreddit_name: str, @@ -132,6 +133,7 @@ def collect_top_posts( return data + @export_to_toolkit def perform_sentiment_analysis( self, data: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: @@ -156,6 +158,7 @@ def perform_sentiment_analysis( return data + @export_to_toolkit def track_keyword_discussions( self, subreddits: List[str], @@ -218,17 +221,3 @@ def track_keyword_discussions( if sentiment_analysis: data = self.perform_sentiment_analysis(data) return data - - def get_tools(self) -> List[OpenAIFunction]: - r"""Returns a list of OpenAIFunction objects representing the - functions in the toolkit. - - Returns: - List[OpenAIFunction]: A list of OpenAIFunction objects for the - toolkit methods. - """ - return [ - OpenAIFunction(self.collect_top_posts), - OpenAIFunction(self.perform_sentiment_analysis), - OpenAIFunction(self.track_keyword_discussions), - ] diff --git a/camel/toolkits/retrieval_toolkit.py b/camel/toolkits/retrieval_toolkit.py index 370b2c66a6..87483c7847 100644 --- a/camel/toolkits/retrieval_toolkit.py +++ b/camel/toolkits/retrieval_toolkit.py @@ -14,10 +14,10 @@ from typing import List, Optional, Union from camel.retrievers import AutoRetriever -from camel.toolkits import OpenAIFunction from camel.toolkits.base import BaseToolkit from camel.types import StorageType from camel.utils import Constants +from camel.utils.commons import export_to_toolkit class RetrievalToolkit(BaseToolkit): @@ -34,6 +34,7 @@ def __init__(self, auto_retriever: Optional[AutoRetriever] = None) -> None: storage_type=StorageType.QDRANT, ) + @export_to_toolkit def information_retrieval( self, query: str, @@ -74,15 +75,3 @@ def information_retrieval( similarity_threshold=similarity_threshold, ) return str(retrieved_info) - - def get_tools(self) -> List[OpenAIFunction]: - r"""Returns a list of OpenAIFunction objects representing the - functions in the toolkit. - - Returns: - List[OpenAIFunction]: A list of OpenAIFunction objects - representing the functions in the toolkit. - """ - return [ - OpenAIFunction(self.information_retrieval), - ] diff --git a/camel/toolkits/search_toolkit.py b/camel/toolkits/search_toolkit.py index bac9423cb2..17bfa7152b 100644 --- a/camel/toolkits/search_toolkit.py +++ b/camel/toolkits/search_toolkit.py @@ -15,312 +15,319 @@ from typing import Any, Dict, List from camel.toolkits.base import BaseToolkit -from camel.toolkits.openai_function import OpenAIFunction +from camel.toolkits.function_tool import FunctionTool +from camel.utils.commons import export_to_toolkit -class SearchToolkit(BaseToolkit): - r"""A class representing a toolkit for web search. +@export_to_toolkit +def search_wiki(entity: str) -> str: + r"""Search the entity in WikiPedia and return the summary of the + required page, containing factual information about + the given entity. - This class provides methods for searching information on the web using - search engines like Google, DuckDuckGo, Wikipedia and Wolfram Alpha. + Args: + entity (str): The entity to be searched. + + Returns: + str: The search result. If the page corresponding to the entity + exists, return the summary of this entity in a string. """ + try: + import wikipedia + except ImportError: + raise ImportError( + "Please install `wikipedia` first. You can install it " + "by running `pip install wikipedia`." + ) - def search_wiki(self, entity: str) -> str: - r"""Search the entity in WikiPedia and return the summary of the - required page, containing factual information about - the given entity. + result: str - Args: - entity (str): The entity to be searched. + try: + result = wikipedia.summary(entity, sentences=5, auto_suggest=False) + except wikipedia.exceptions.DisambiguationError as e: + result = wikipedia.summary( + e.options[0], sentences=5, auto_suggest=False + ) + except wikipedia.exceptions.PageError: + result = ( + "There is no page in Wikipedia corresponding to entity " + f"{entity}, please specify another word to describe the" + " entity to be searched." + ) + except wikipedia.exceptions.WikipediaException as e: + result = f"An exception occurred during the search: {e}" - Returns: - str: The search result. If the page corresponding to the entity - exists, return the summary of this entity in a string. - """ - try: - import wikipedia - except ImportError: - raise ImportError( - "Please install `wikipedia` first. You can install it " - "by running `pip install wikipedia`." - ) + return result - result: str - try: - result = wikipedia.summary(entity, sentences=5, auto_suggest=False) - except wikipedia.exceptions.DisambiguationError as e: - result = wikipedia.summary( - e.options[0], sentences=5, auto_suggest=False - ) - except wikipedia.exceptions.PageError: - result = ( - "There is no page in Wikipedia corresponding to entity " - f"{entity}, please specify another word to describe the" - " entity to be searched." - ) - except wikipedia.exceptions.WikipediaException as e: - result = f"An exception occurred during the search: {e}" - - return result - - def search_duckduckgo( - self, query: str, source: str = "text", max_results: int = 5 - ) -> List[Dict[str, Any]]: - r"""Use DuckDuckGo search engine to search information for - the given query. - - This function queries the DuckDuckGo API for related topics to - the given search term. The results are formatted into a list of - dictionaries, each representing a search result. - - Args: - query (str): The query to be searched. - source (str): The type of information to query (e.g., "text", - "images", "videos"). Defaults to "text". - max_results (int): Max number of results, defaults to `5`. +@export_to_toolkit +def search_duckduckgo( + query: str, source: str = "text", max_results: int = 5 +) -> List[Dict[str, Any]]: + r"""Use DuckDuckGo search engine to search information for + the given query. - Returns: - List[Dict[str, Any]]: A list of dictionaries where each dictionary - represents a search result. - """ - from duckduckgo_search import DDGS - from requests.exceptions import RequestException - - ddgs = DDGS() - responses: List[Dict[str, Any]] = [] - - if source == "text": - try: - results = ddgs.text(keywords=query, max_results=max_results) - except RequestException as e: - # Handle specific exceptions or general request exceptions - responses.append({"error": f"duckduckgo search failed.{e}"}) - - # Iterate over results found - for i, result in enumerate(results, start=1): - # Creating a response object with a similar structure - response = { - "result_id": i, - "title": result["title"], - "description": result["body"], - "url": result["href"], - } - responses.append(response) + This function queries the DuckDuckGo API for related topics to + the given search term. The results are formatted into a list of + dictionaries, each representing a search result. - elif source == "images": - try: - results = ddgs.images(keywords=query, max_results=max_results) - except RequestException as e: - # Handle specific exceptions or general request exceptions - responses.append({"error": f"duckduckgo search failed.{e}"}) + Args: + query (str): The query to be searched. + source (str): The type of information to query (e.g., "text", + "images", "videos"). Defaults to "text". + max_results (int): Max number of results, defaults to `5`. - # Iterate over results found - for i, result in enumerate(results, start=1): - # Creating a response object with a similar structure - response = { - "result_id": i, - "title": result["title"], - "image": result["image"], - "url": result["url"], - "source": result["source"], - } - responses.append(response) + Returns: + List[Dict[str, Any]]: A list of dictionaries where each dictionary + represents a search result. + """ + from duckduckgo_search import DDGS + from requests.exceptions import RequestException - elif source == "videos": - try: - results = ddgs.videos(keywords=query, max_results=max_results) - except RequestException as e: - # Handle specific exceptions or general request exceptions - responses.append({"error": f"duckduckgo search failed.{e}"}) + ddgs = DDGS() + responses: List[Dict[str, Any]] = [] - # Iterate over results found - for i, result in enumerate(results, start=1): - # Creating a response object with a similar structure + if source == "text": + try: + results = ddgs.text(keywords=query, max_results=max_results) + except RequestException as e: + # Handle specific exceptions or general request exceptions + responses.append({"error": f"duckduckgo search failed.{e}"}) + + # Iterate over results found + for i, result in enumerate(results, start=1): + # Creating a response object with a similar structure + response = { + "result_id": i, + "title": result["title"], + "description": result["body"], + "url": result["href"], + } + responses.append(response) + + elif source == "images": + try: + results = ddgs.images(keywords=query, max_results=max_results) + except RequestException as e: + # Handle specific exceptions or general request exceptions + responses.append({"error": f"duckduckgo search failed.{e}"}) + + # Iterate over results found + for i, result in enumerate(results, start=1): + # Creating a response object with a similar structure + response = { + "result_id": i, + "title": result["title"], + "image": result["image"], + "url": result["url"], + "source": result["source"], + } + responses.append(response) + + elif source == "videos": + try: + results = ddgs.videos(keywords=query, max_results=max_results) + except RequestException as e: + # Handle specific exceptions or general request exceptions + responses.append({"error": f"duckduckgo search failed.{e}"}) + + # Iterate over results found + for i, result in enumerate(results, start=1): + # Creating a response object with a similar structure + response = { + "result_id": i, + "title": result["title"], + "description": result["description"], + "embed_url": result["embed_url"], + "publisher": result["publisher"], + "duration": result["duration"], + "published": result["published"], + } + responses.append(response) + + # If no answer found, return an empty list + return responses + + +@export_to_toolkit +def search_google( + query: str, num_result_pages: int = 5 +) -> List[Dict[str, Any]]: + r"""Use Google search engine to search information for the given query. + + Args: + query (str): The query to be searched. + num_result_pages (int): The number of result pages to retrieve. + + Returns: + List[Dict[str, Any]]: A list of dictionaries where each dictionary + represents a website. + Each dictionary contains the following keys: + - 'result_id': A number in order. + - 'title': The title of the website. + - 'description': A brief description of the website. + - 'long_description': More detail of the website. + - 'url': The URL of the website. + + Example: + { + 'result_id': 1, + 'title': 'OpenAI', + 'description': 'An organization focused on ensuring that + artificial general intelligence benefits all of humanity.', + 'long_description': 'OpenAI is a non-profit artificial + intelligence research company. Our goal is to advance + digital intelligence in the way that is most likely to + benefit humanity as a whole', + 'url': 'https://www.openai.com' + } + title, description, url of a website. + """ + import requests + + # https://developers.google.com/custom-search/v1/overview + GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") + # https://cse.google.com/cse/all + SEARCH_ENGINE_ID = os.getenv("SEARCH_ENGINE_ID") + + # Using the first page + start_page_idx = 1 + # Different language may get different result + search_language = "en" + # How many pages to return + num_result_pages = num_result_pages + # Constructing the URL + # Doc: https://developers.google.com/custom-search/v1/using_rest + url = ( + f"https://www.googleapis.com/customsearch/v1?" + f"key={GOOGLE_API_KEY}&cx={SEARCH_ENGINE_ID}&q={query}&start=" + f"{start_page_idx}&lr={search_language}&num={num_result_pages}" + ) + + responses = [] + # Fetch the results given the URL + try: + # Make the get + result = requests.get(url) + data = result.json() + + # Get the result items + if "items" in data: + search_items = data.get("items") + + # Iterate over 10 results found + for i, search_item in enumerate(search_items, start=1): + if "og:description" in search_item["pagemap"]["metatags"][0]: + long_description = search_item["pagemap"]["metatags"][0][ + "og:description" + ] + else: + long_description = "N/A" + # Get the page title + title = search_item.get("title") + # Page snippet + snippet = search_item.get("snippet") + + # Extract the page url + link = search_item.get("link") response = { "result_id": i, - "title": result["title"], - "description": result["description"], - "embed_url": result["embed_url"], - "publisher": result["publisher"], - "duration": result["duration"], - "published": result["published"], + "title": title, + "description": snippet, + "long_description": long_description, + "url": link, } responses.append(response) + else: + responses.append({"error": "google search failed."}) - # If no answer found, return an empty list - return responses + except requests.RequestException: + # Handle specific exceptions or general request exceptions + responses.append({"error": "google search failed."}) + # If no answer found, return an empty list + return responses - def search_google( - self, query: str, num_result_pages: int = 5 - ) -> List[Dict[str, Any]]: - r"""Use Google search engine to search information for the given query. - Args: - query (str): The query to be searched. - num_result_pages (int): The number of result pages to retrieve. +@export_to_toolkit +def query_wolfram_alpha(query: str, is_detailed: bool) -> str: + r"""Queries Wolfram|Alpha and returns the result. Wolfram|Alpha is an + answer engine developed by Wolfram Research. It is offered as an online + service that answers factual queries by computing answers from + externally sourced data. - Returns: - List[Dict[str, Any]]: A list of dictionaries where each dictionary - represents a website. - Each dictionary contains the following keys: - - 'result_id': A number in order. - - 'title': The title of the website. - - 'description': A brief description of the website. - - 'long_description': More detail of the website. - - 'url': The URL of the website. - - Example: - { - 'result_id': 1, - 'title': 'OpenAI', - 'description': 'An organization focused on ensuring that - artificial general intelligence benefits all of humanity.', - 'long_description': 'OpenAI is a non-profit artificial - intelligence research company. Our goal is to advance - digital intelligence in the way that is most likely to - benefit humanity as a whole', - 'url': 'https://www.openai.com' - } - title, description, url of a website. - """ - import requests - - # https://developers.google.com/custom-search/v1/overview - GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") - # https://cse.google.com/cse/all - SEARCH_ENGINE_ID = os.getenv("SEARCH_ENGINE_ID") - - # Using the first page - start_page_idx = 1 - # Different language may get different result - search_language = "en" - # How many pages to return - num_result_pages = num_result_pages - # Constructing the URL - # Doc: https://developers.google.com/custom-search/v1/using_rest - url = ( - f"https://www.googleapis.com/customsearch/v1?" - f"key={GOOGLE_API_KEY}&cx={SEARCH_ENGINE_ID}&q={query}&start=" - f"{start_page_idx}&lr={search_language}&num={num_result_pages}" + Args: + query (str): The query to send to Wolfram Alpha. + is_detailed (bool): Whether to include additional details in the + result. + + Returns: + str: The result from Wolfram Alpha, formatted as a string. + """ + try: + import wolframalpha + except ImportError: + raise ImportError( + "Please install `wolframalpha` first. You can install it by" + " running `pip install wolframalpha`." ) - responses = [] - # Fetch the results given the URL - try: - # Make the get - result = requests.get(url) - data = result.json() + WOLFRAMALPHA_APP_ID = os.environ.get('WOLFRAMALPHA_APP_ID') + if not WOLFRAMALPHA_APP_ID: + raise ValueError( + "`WOLFRAMALPHA_APP_ID` not found in environment " + "variables. Get `WOLFRAMALPHA_APP_ID` here: " + "`https://products.wolframalpha.com/api/`." + ) - # Get the result items - if "items" in data: - search_items = data.get("items") + try: + client = wolframalpha.Client(WOLFRAMALPHA_APP_ID) + res = client.query(query) + assumption = next(res.pods).text or "No assumption made." + answer = next(res.results).text or "No answer found." + except Exception as e: + if isinstance(e, StopIteration): + return "Wolfram Alpha wasn't able to answer it" + else: + error_message = f"Wolfram Alpha wasn't able to answer it" f"{e!s}." + return error_message - # Iterate over 10 results found - for i, search_item in enumerate(search_items, start=1): - if ( - "og:description" - in search_item["pagemap"]["metatags"][0] - ): - long_description = search_item["pagemap"]["metatags"][ - 0 - ]["og:description"] - else: - long_description = "N/A" - # Get the page title - title = search_item.get("title") - # Page snippet - snippet = search_item.get("snippet") - - # Extract the page url - link = search_item.get("link") - response = { - "result_id": i, - "title": title, - "description": snippet, - "long_description": long_description, - "url": link, - } - responses.append(response) - else: - responses.append({"error": "google search failed."}) - - except requests.RequestException: - # Handle specific exceptions or general request exceptions - responses.append({"error": "google search failed."}) - # If no answer found, return an empty list - return responses + result = f"Assumption:\n{assumption}\n\nAnswer:\n{answer}" - def query_wolfram_alpha(self, query: str, is_detailed: bool) -> str: - r"""Queries Wolfram|Alpha and returns the result. Wolfram|Alpha is an - answer engine developed by Wolfram Research. It is offered as an online - service that answers factual queries by computing answers from - externally sourced data. + # Add additional details in the result + if is_detailed: + result += '\n' + for pod in res.pods: + result += '\n' + pod['@title'] + ':\n' + for sub in pod.subpods: + result += (sub.plaintext or "None") + '\n' - Args: - query (str): The query to send to Wolfram Alpha. - is_detailed (bool): Whether to include additional details in the - result. + return result.rstrip() # Remove trailing whitespace - Returns: - str: The result from Wolfram Alpha, formatted as a string. - """ - try: - import wolframalpha - except ImportError: - raise ImportError( - "Please install `wolframalpha` first. You can install it by" - " running `pip install wolframalpha`." - ) - - WOLFRAMALPHA_APP_ID = os.environ.get('WOLFRAMALPHA_APP_ID') - if not WOLFRAMALPHA_APP_ID: - raise ValueError( - "`WOLFRAMALPHA_APP_ID` not found in environment " - "variables. Get `WOLFRAMALPHA_APP_ID` here: " - "`https://products.wolframalpha.com/api/`." - ) - try: - client = wolframalpha.Client(WOLFRAMALPHA_APP_ID) - res = client.query(query) - assumption = next(res.pods).text or "No assumption made." - answer = next(res.results).text or "No answer found." - except Exception as e: - if isinstance(e, StopIteration): - return "Wolfram Alpha wasn't able to answer it" - else: - error_message = ( - f"Wolfram Alpha wasn't able to answer it" f"{e!s}." - ) - return error_message - - result = f"Assumption:\n{assumption}\n\nAnswer:\n{answer}" - - # Add additional details in the result - if is_detailed: - result += '\n' - for pod in res.pods: - result += '\n' + pod['@title'] + ':\n' - for sub in pod.subpods: - result += (sub.plaintext or "None") + '\n' - - return result.rstrip() # Remove trailing whitespace - - def get_tools(self) -> List[OpenAIFunction]: - r"""Returns a list of OpenAIFunction objects representing the +class SearchToolkit(BaseToolkit): + r"""A class representing a toolkit for web search. + + This class provides methods for searching information on the web using + search engines like Google, DuckDuckGo, Wikipedia and Wolfram Alpha. + """ + + def get_tools(self) -> List[FunctionTool]: + r"""Returns a list of FunctionTool objects representing the functions in the toolkit. Returns: - List[OpenAIFunction]: A list of OpenAIFunction objects + List[FunctionTool]: A list of FunctionTool objects representing the functions in the toolkit. """ - return [ - OpenAIFunction(self.search_wiki), - OpenAIFunction(self.search_google), - OpenAIFunction(self.search_duckduckgo), - OpenAIFunction(self.query_wolfram_alpha), - ] - - -SEARCH_FUNCS: List[OpenAIFunction] = SearchToolkit().get_tools() + return SEARCH_FUNCS + + +SEARCH_FUNCS = [ + FunctionTool(func=func, name_prefix=SearchToolkit.__name__) + for func in ( + search_wiki, + search_duckduckgo, + search_google, + query_wolfram_alpha, + ) +] diff --git a/camel/toolkits/slack_toolkit.py b/camel/toolkits/slack_toolkit.py index f386397d40..d2b24fff5e 100644 --- a/camel/toolkits/slack_toolkit.py +++ b/camel/toolkits/slack_toolkit.py @@ -17,16 +17,16 @@ import json import logging import os -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from camel.toolkits.base import BaseToolkit +from camel.utils.commons import export_to_toolkit if TYPE_CHECKING: from ssl import SSLContext from slack_sdk import WebClient -from camel.toolkits import OpenAIFunction logger = logging.getLogger(__name__) @@ -81,6 +81,7 @@ def _login_slack( logger.info("Slack login successful.") return client + @export_to_toolkit def create_slack_channel( self, name: str, is_private: Optional[bool] = True ) -> str: @@ -112,6 +113,7 @@ def create_slack_channel( except SlackApiError as e: return f"Error creating conversation: {e.response['error']}" + @export_to_toolkit def join_slack_channel(self, channel_id: str) -> str: r"""Joins an existing Slack channel. @@ -135,6 +137,7 @@ def join_slack_channel(self, channel_id: str) -> str: except SlackApiError as e: return f"Error creating conversation: {e.response['error']}" + @export_to_toolkit def leave_slack_channel(self, channel_id: str) -> str: r"""Leaves an existing Slack channel. @@ -158,6 +161,7 @@ def leave_slack_channel(self, channel_id: str) -> str: except SlackApiError as e: return f"Error creating conversation: {e.response['error']}" + @export_to_toolkit def get_slack_channel_information(self) -> str: r"""Retrieve Slack channels and return relevant information in JSON format. @@ -191,6 +195,7 @@ def get_slack_channel_information(self) -> str: except SlackApiError as e: return f"Error creating conversation: {e.response['error']}" + @export_to_toolkit def get_slack_channel_message(self, channel_id: str) -> str: r"""Retrieve messages from a Slack channel. @@ -220,6 +225,7 @@ def get_slack_channel_message(self, channel_id: str) -> str: except SlackApiError as e: return f"Error retrieving messages: {e.response['error']}" + @export_to_toolkit def send_slack_message( self, message: str, @@ -257,6 +263,7 @@ def send_slack_message( except SlackApiError as e: return f"Error creating conversation: {e.response['error']}" + @export_to_toolkit def delete_slack_message( self, time_stamp: str, @@ -285,21 +292,3 @@ def delete_slack_message( return str(response) except SlackApiError as e: return f"Error creating conversation: {e.response['error']}" - - def get_tools(self) -> List[OpenAIFunction]: - r"""Returns a list of OpenAIFunction objects representing the - functions in the toolkit. - - Returns: - List[OpenAIFunction]: A list of OpenAIFunction objects - representing the functions in the toolkit. - """ - return [ - OpenAIFunction(self.create_slack_channel), - OpenAIFunction(self.join_slack_channel), - OpenAIFunction(self.leave_slack_channel), - OpenAIFunction(self.get_slack_channel_information), - OpenAIFunction(self.get_slack_channel_message), - OpenAIFunction(self.send_slack_message), - OpenAIFunction(self.delete_slack_message), - ] diff --git a/camel/toolkits/toolkits_manager.py b/camel/toolkits/toolkits_manager.py new file mode 100644 index 0000000000..80a7e61a07 --- /dev/null +++ b/camel/toolkits/toolkits_manager.py @@ -0,0 +1,367 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +import importlib +import inspect +import pkgutil +from typing import Callable, List, Optional, Union +from warnings import warn + +from camel.toolkits.base import BaseToolkit +from camel.toolkits.function_tool import FunctionTool + + +class ToolkitManager: + r""" + A class representing a manager for dynamically loading and accessing + toolkits. + + The ToolkitManager loads all callable toolkits from the `camel.toolkits` + package and provides methods to list, retrieve, and search them as + FunctionTool objects. + """ + + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super(ToolkitManager, cls).__new__( + cls, *args, **kwargs + ) + return cls._instance + + def __init__(self): + r""" + Initializes the ToolkitManager and loads all available toolkits. + """ + if not hasattr(self, '_initialized'): + self._initialized = True + self.toolkits = {} + self.toolkit_classes = {} + self.toolkit_class_methods = {} + self._load_toolkits() + self._load_toolkit_class_and_methods() + + def _load_toolkits(self): + r""" + Dynamically loads all toolkit functions from the `camel.toolkits` + package. + + For each module in the package, it checks for functions decorated with + `@export_to_toolkit`, which adds the `_is_exported` attribute. + """ + package = importlib.import_module('camel.toolkits') + for _, module_name, _ in pkgutil.iter_modules(package.__path__): + module = importlib.import_module(f'camel.toolkits.{module_name}') + + base_toolkit_class_name = None + for _, cls in inspect.getmembers(module, inspect.isclass): + if issubclass(cls, BaseToolkit) and cls is not BaseToolkit: + base_toolkit_class_name = cls.__name__ + break + + prefix = ( + base_toolkit_class_name + '.' + if base_toolkit_class_name + else '' + ) + + for name, func in inspect.getmembers(module, inspect.isfunction): + if ( + hasattr(func, '_is_exported') + and func.__module__ == module.__name__ + ): + self.toolkits[f"{prefix}{name}"] = func + + def _load_toolkit_class_and_methods(self): + r""" + Dynamically loads all classes and their exported methods from the + `camel.toolkits` package. + + For each module in the package, it identifies public classes. For each + class, it collects only those methods that are decorated with + `@export_to_toolkit`, which adds the `_is_exported` attribute. + + In addition to class methods, it also collects standalone functions + from the target module that are decorated with `@export_to_toolkit` + and adds them to the corresponding toolkit class. This allows + including both class methods and standalone functions under the + same toolkit class. + """ + package = importlib.import_module('camel.toolkits') + + for _, module_name, _ in pkgutil.iter_modules(package.__path__): + module = importlib.import_module(f'camel.toolkits.{module_name}') + toolkit_class_name = None + for name, cls in inspect.getmembers(module, inspect.isclass): + if cls.__module__ == module.__name__ and issubclass( + cls, BaseToolkit + ): + self.toolkit_classes[name] = cls + + self.toolkit_class_methods[name] = { + method_name: method + for method_name, method in inspect.getmembers( + cls, inspect.isfunction + ) + if callable(method) and hasattr(method, '_is_exported') + } + toolkit_class_name = name + if toolkit_class_name: + for name, func in inspect.getmembers( + module, inspect.isfunction + ): + if ( + hasattr(func, '_is_exported') + and func.__module__ == module.__name__ + ): + self.toolkit_class_methods[toolkit_class_name][ + name + ] = func + + def register_tool( + self, + toolkit_obj: Union[Callable, object, List[Union[Callable, object]]], + ) -> List[FunctionTool]: + r""" + Registers a toolkit function or instance and adds it to the toolkits + list. If the input is a list, it processes each element in the list. + + Parameters: + toolkit_obj (Union[Callable, object, List[Union[Callable, + object]]]): The toolkit function(s) or instance(s) to be + registered. + + Returns: + FunctionTools (List[FunctionTool]): A list of FunctionTool + instances. + """ + res_openai_functions = [] + + # If the input is a list, process each element + if isinstance(toolkit_obj, list): + for obj in toolkit_obj: + res_openai_functions.extend(self._register_single_tool(obj)) + else: + res_openai_functions = self._register_single_tool(toolkit_obj) + + return res_openai_functions + + def _register_single_tool( + self, toolkit_obj: Union[Callable, object] + ) -> List[FunctionTool]: + r""" + Helper function to register a single toolkit function or instance. + + Parameters: + toolkit_obj (Union[Callable, object]): The toolkit function or + instance to be processed. + + Returns: + FunctionTools (tuple[List[FunctionTool]): A list of FunctionTool + instances. + """ + res_openai_functions = [] + if callable(toolkit_obj): + if self.add_toolkit_from_function(toolkit_obj): + res_openai_functions.append(FunctionTool(toolkit_obj)) + else: + if self.add_toolkit_from_instance( + **{toolkit_obj.__class__.__name__: toolkit_obj} + ) and hasattr(toolkit_obj, 'get_tools'): + res_openai_functions.extend(toolkit_obj.get_tools()) + return res_openai_functions + + def add_toolkit_from_function(self, toolkit_func: Callable) -> bool: + r""" + Adds a toolkit function to the toolkits list. + + Parameters: + toolkit_func (Callable): The toolkit function to be added. + + Returns: + Bool: True if the addition was successful, False otherwise. + """ + if not callable(toolkit_func): + warn("Provided argument is not a callable function.") + + func_name = toolkit_func.__name__ + + if not func_name: + warn("Function must have a valid name.") + + self.toolkits[func_name] = toolkit_func + + return True + + def add_toolkit_from_instance(self, **kwargs) -> bool: + r""" + Add a toolkit class instance to the tool list. + Custom instance names are supported here. + + Parameters: + kwargs: The toolkit class instance to be added. Keyword arguments + where each value is expected to be an instance of BaseToolkit. + + Returns: + Bool: True if the addition was successful, False otherwise. + """ + is_method_added = False + + for toolkit_instance_name, toolkit_instance in kwargs.items(): + if isinstance(toolkit_instance, BaseToolkit): + for attr_name in dir(toolkit_instance): + attr = getattr(toolkit_instance, attr_name) + + if callable(attr) and hasattr(attr, '_is_exported'): + method_name = f"{toolkit_instance_name}.{attr_name}" + self.toolkits[method_name] = attr + is_method_added = True + + else: + warn( + f"Failed to add {toolkit_instance_name}: " + + "Not an instance of BaseToolkit." + ) + + return is_method_added + + def list_toolkits(self) -> List[str]: + r""" + Lists the names of all available toolkits. + + Returns: + List[str]: A list of all toolkit function names available for use. + """ + return list(self.toolkits.keys()) + + def list_toolkit_classes(self) -> List[str]: + r""" + Lists the names of all available toolkit classes along with their + methods. + + Returns: + List[str]: A list of strings in the format 'ClassName: method1, + method2, ...'. + """ + result = [] + + for class_name, methods in self.toolkit_class_methods.items(): + if methods: + methods_str = ', '.join(methods) + + formatted_string = f"{class_name}: {methods_str}" + + result.append(formatted_string) + + return result + + def get_toolkit(self, name: str) -> Optional[FunctionTool]: + r""" + Retrieves the specified toolkit as an FunctionTool object. + + Args: + name (str): The name of the toolkit function to retrieve. + + Returns: + FunctionTool (optional): The toolkit wrapped as an FunctionTool. + """ + toolkit = self.toolkits.get(name) + if toolkit: + return FunctionTool(func=toolkit, name_prefix=name.split('.')[0]) + return None + + def get_toolkits(self, names: list[str]) -> list[FunctionTool]: + r""" + Retrieves the specified toolkit as an FunctionTool object. + + Args: + name (str): The name of the toolkit function to retrieve. + + Returns: + FunctionTools (list): The toolkits wrapped as FunctionTools. + """ + toolkits: list[FunctionTool] = [] + for name in names: + current_toolkit = self.toolkits.get(name) + if current_toolkit: + toolkits.append( + FunctionTool( + func=current_toolkit, name_prefix=name.split('.')[0] + ) + ) + return toolkits + + def get_toolkit_class( + self, class_name: str + ) -> Optional[type[BaseToolkit]]: + r""" + Retrieves the specified toolkit class. + + Args: + class_name (str): The name of the toolkit class to retrieve. + + Returns: + BaseToolkit(optional): The toolkit class object if found. + """ + toolkit_class = self.toolkit_classes.get(class_name) + if toolkit_class: + return toolkit_class + return None + + def _default_search_algorithm( + self, keyword: str, description: str + ) -> bool: + r""" + Default search algorithm. + + Args: + keyword (str): The keyword to search for. + description (str): The description to search within. + + Returns: + bool: True if a match is found based on similarity, False + otherwise. + """ + return keyword.lower() in description.lower() + + def search_toolkits( + self, + keyword: str, + algorithm: Optional[Callable[[str, str], bool]] = None, + ) -> Optional[List[FunctionTool]]: + r""" + Searches for toolkits based on a keyword in their descriptions using + the provided search algorithm. + + Args: + keyword (str): The keyword to search for in toolkit descriptions. + algorithm (Callable[[str, str], bool], optional): A custom + search algorithm function that accepts the keyword and + description and returns a boolean. Defaults to fuzzy matching. + + Returns: + FunctionTools (list): A list of toolkit names whose descriptions + match the keyword, otherwise an error message. + """ + if algorithm is None: + algorithm = self._default_search_algorithm + + matching_toolkits_names = [] + for name, func in self.toolkits.items(): + openai_func = FunctionTool(func) + description = openai_func.get_function_description() + if algorithm(keyword, description) or algorithm(keyword, name): + matching_toolkits_names.append(name) + + return self.get_toolkits(matching_toolkits_names) diff --git a/camel/toolkits/twitter_toolkit.py b/camel/toolkits/twitter_toolkit.py index b5a744f89d..fd26761f92 100644 --- a/camel/toolkits/twitter_toolkit.py +++ b/camel/toolkits/twitter_toolkit.py @@ -19,8 +19,8 @@ import requests -from camel.toolkits import OpenAIFunction from camel.toolkits.base import BaseToolkit +from camel.utils.commons import export_to_toolkit TWEET_TEXT_LIMIT = 280 @@ -32,6 +32,7 @@ class TwitterToolkit(BaseToolkit): getting the authenticated user's profile information. """ + @export_to_toolkit def create_tweet( self, *, @@ -162,6 +163,7 @@ def create_tweet( return response_str + @export_to_toolkit def delete_tweet(self, tweet_id: str) -> str: r"""Deletes a tweet with the specified ID for an authorized user. @@ -226,6 +228,7 @@ def delete_tweet(self, tweet_id: str) -> str: ) return response_str + @export_to_toolkit def get_my_user_profile(self) -> str: r"""Retrieves and formats the authenticated user's Twitter profile info. @@ -365,20 +368,6 @@ def get_my_user_profile(self) -> str: return user_report - def get_tools(self) -> List[OpenAIFunction]: - r"""Returns a list of OpenAIFunction objects representing the - functions in the toolkit. - - Returns: - List[OpenAIFunction]: A list of OpenAIFunction objects - representing the functions in the toolkit. - """ - return [ - OpenAIFunction(self.create_tweet), - OpenAIFunction(self.delete_tweet), - OpenAIFunction(self.get_my_user_profile), - ] - def _get_twitter_api_key(self) -> Tuple[str, str]: r"""Retrieve the Twitter API key and secret from environment variables. diff --git a/camel/toolkits/weather_toolkit.py b/camel/toolkits/weather_toolkit.py index 72980da5a5..9ddf57502c 100644 --- a/camel/toolkits/weather_toolkit.py +++ b/camel/toolkits/weather_toolkit.py @@ -15,7 +15,142 @@ from typing import List, Literal from camel.toolkits.base import BaseToolkit -from camel.toolkits.openai_function import OpenAIFunction +from camel.toolkits.function_tool import FunctionTool +from camel.utils.commons import export_to_toolkit + + +def get_openweathermap_api_key() -> str: + r"""Retrieve the OpenWeatherMap API key from environment variables. + + Returns: + str: The OpenWeatherMap API key. + + Raises: + ValueError: If the API key is not found in the environment + variables. + """ + # Get `OPENWEATHERMAP_API_KEY` here: https://openweathermap.org + api_key = os.environ.get('OPENWEATHERMAP_API_KEY') + if not api_key: + raise ValueError( + "`OPENWEATHERMAP_API_KEY` not found in environment " + "variables. Get `OPENWEATHERMAP_API_KEY` here: " + "`https://openweathermap.org`." + ) + return api_key + + +@export_to_toolkit +def get_weather_data( + city: str, + temp_units: Literal['kelvin', 'celsius', 'fahrenheit'] = 'kelvin', + wind_units: Literal[ + 'meters_sec', 'miles_hour', 'knots', 'beaufort' + ] = 'meters_sec', + visibility_units: Literal['meters', 'miles'] = 'meters', + time_units: Literal['unix', 'iso', 'date'] = 'unix', +) -> str: + r"""Fetch and return a comprehensive weather report for a given city + as a string. The report includes current weather conditions, + temperature, wind details, visibility, and sunrise/sunset times, + all formatted as a readable string. + + The function interacts with the OpenWeatherMap API to + retrieve the data. + + Args: + city (str): The name of the city for which the weather information + is desired. Format "City, CountryCode" (e.g., "Paris, FR" + for Paris, France). If the country code is not provided, + the API will search for the city in all countries, which + may yield incorrect results if multiple cities with the + same name exist. + temp_units (Literal['kelvin', 'celsius', 'fahrenheit']): Units for + temperature. (default: :obj:`kelvin`) + wind_units + (Literal['meters_sec', 'miles_hour', 'knots', 'beaufort']): + Units for wind speed. (default: :obj:`meters_sec`) + visibility_units (Literal['meters', 'miles']): Units for visibility + distance. (default: :obj:`meters`) + time_units (Literal['unix', 'iso', 'date']): Format for sunrise and + sunset times. (default: :obj:`unix`) + + Returns: + str: A string containing the fetched weather data, formatted in a + readable manner. If an error occurs, a message indicating the + error will be returned instead. + + Example of return string: + "Weather in Paris, FR: 15°C, feels like 13°C. Max temp: 17°C, + Min temp : 12°C. + Wind: 5 m/s at 270 degrees. Visibility: 10 kilometers. + Sunrise at 05:46:05 (UTC), Sunset at 18:42:20 (UTC)." + + Note: + Please ensure that the API key is valid and has permissions + to access the weather data. + """ + # NOTE: This tool may not work as expected since the input arguments + # like `time_units` should be enum types which are not supported yet. + + try: + import pyowm + except ImportError: + raise ImportError( + "Please install `pyowm` first. You can install it by running " + "`pip install pyowm`." + ) + + api_key = get_openweathermap_api_key() + owm = pyowm.OWM(api_key) + mgr = owm.weather_manager() + + try: + observation = mgr.weather_at_place(city) + weather = observation.weather + + # Temperature + temperature = weather.temperature(temp_units) + + # Wind + wind_data = observation.weather.wind(unit=wind_units) + wind_speed = wind_data.get('speed') + # 'N/A' if the degree is not available + wind_deg = wind_data.get('deg', 'N/A') + + # Visibility + visibility_distance = observation.weather.visibility_distance + visibility = ( + str(visibility_distance) + if visibility_units == 'meters' + else str(observation.weather.visibility(unit='miles')) + ) + + # Sunrise and Sunset + sunrise_time = str(weather.sunrise_time(timeformat=time_units)) + sunset_time = str(weather.sunset_time(timeformat=time_units)) + + # Compile all the weather details into a report string + weather_report = ( + f"Weather in {city}: " + f"{temperature['temp']}°{temp_units.title()}, " + f"feels like " + f"{temperature['feels_like']}°{temp_units.title()}. " + f"Max temp: {temperature['temp_max']}°{temp_units.title()}, " + f"Min temp: {temperature['temp_min']}°{temp_units.title()}. " + f"Wind: {wind_speed} {wind_units} at {wind_deg} degrees. " + f"Visibility: {visibility} {visibility_units}. " + f"Sunrise at {sunrise_time}, Sunset at {sunset_time}." + ) + + return weather_report + + except Exception as e: + error_message = ( + f"An error occurred while fetching weather data for {city}: " + f"{e!s}." + ) + return error_message class WeatherToolkit(BaseToolkit): @@ -25,149 +160,17 @@ class WeatherToolkit(BaseToolkit): using the OpenWeatherMap API. """ - def get_openweathermap_api_key(self) -> str: - r"""Retrieve the OpenWeatherMap API key from environment variables. - - Returns: - str: The OpenWeatherMap API key. - - Raises: - ValueError: If the API key is not found in the environment - variables. - """ - # Get `OPENWEATHERMAP_API_KEY` here: https://openweathermap.org - OPENWEATHERMAP_API_KEY = os.environ.get('OPENWEATHERMAP_API_KEY') - if not OPENWEATHERMAP_API_KEY: - raise ValueError( - "`OPENWEATHERMAP_API_KEY` not found in environment " - "variables. Get `OPENWEATHERMAP_API_KEY` here: " - "`https://openweathermap.org`." - ) - return OPENWEATHERMAP_API_KEY - - def get_weather_data( - self, - city: str, - temp_units: Literal['kelvin', 'celsius', 'fahrenheit'] = 'kelvin', - wind_units: Literal[ - 'meters_sec', 'miles_hour', 'knots', 'beaufort' - ] = 'meters_sec', - visibility_units: Literal['meters', 'miles'] = 'meters', - time_units: Literal['unix', 'iso', 'date'] = 'unix', - ) -> str: - r"""Fetch and return a comprehensive weather report for a given city - as a string. The report includes current weather conditions, - temperature, wind details, visibility, and sunrise/sunset times, - all formatted as a readable string. - - The function interacts with the OpenWeatherMap API to - retrieve the data. - - Args: - city (str): The name of the city for which the weather information - is desired. Format "City, CountryCode" (e.g., "Paris, FR" - for Paris, France). If the country code is not provided, - the API will search for the city in all countries, which - may yield incorrect results if multiple cities with the - same name exist. - temp_units (Literal['kelvin', 'celsius', 'fahrenheit']): Units for - temperature. (default: :obj:`kelvin`) - wind_units - (Literal['meters_sec', 'miles_hour', 'knots', 'beaufort']): - Units for wind speed. (default: :obj:`meters_sec`) - visibility_units (Literal['meters', 'miles']): Units for visibility - distance. (default: :obj:`meters`) - time_units (Literal['unix', 'iso', 'date']): Format for sunrise and - sunset times. (default: :obj:`unix`) - - Returns: - str: A string containing the fetched weather data, formatted in a - readable manner. If an error occurs, a message indicating the - error will be returned instead. - - Example of return string: - "Weather in Paris, FR: 15°C, feels like 13°C. Max temp: 17°C, - Min temp : 12°C. - Wind: 5 m/s at 270 degrees. Visibility: 10 kilometers. - Sunrise at 05:46:05 (UTC), Sunset at 18:42:20 (UTC)." - - Note: - Please ensure that the API key is valid and has permissions - to access the weather data. - """ - # NOTE: This tool may not work as expected since the input arguments - # like `time_units` should be enum types which are not supported yet. - - try: - import pyowm - except ImportError: - raise ImportError( - "Please install `pyowm` first. You can install it by running " - "`pip install pyowm`." - ) - - OPENWEATHERMAP_API_KEY = self.get_openweathermap_api_key() - owm = pyowm.OWM(OPENWEATHERMAP_API_KEY) - mgr = owm.weather_manager() - - try: - observation = mgr.weather_at_place(city) - weather = observation.weather - - # Temperature - temperature = weather.temperature(temp_units) - - # Wind - wind_data = observation.weather.wind(unit=wind_units) - wind_speed = wind_data.get('speed') - # 'N/A' if the degree is not available - wind_deg = wind_data.get('deg', 'N/A') - - # Visibility - visibility_distance = observation.weather.visibility_distance - visibility = ( - str(visibility_distance) - if visibility_units == 'meters' - else str(observation.weather.visibility(unit='miles')) - ) - - # Sunrise and Sunset - sunrise_time = str(weather.sunrise_time(timeformat=time_units)) - sunset_time = str(weather.sunset_time(timeformat=time_units)) - - # Compile all the weather details into a report string - weather_report = ( - f"Weather in {city}: " - f"{temperature['temp']}°{temp_units.title()}, " - f"feels like " - f"{temperature['feels_like']}°{temp_units.title()}. " - f"Max temp: {temperature['temp_max']}°{temp_units.title()}, " - f"Min temp: {temperature['temp_min']}°{temp_units.title()}. " - f"Wind: {wind_speed} {wind_units} at {wind_deg} degrees. " - f"Visibility: {visibility} {visibility_units}. " - f"Sunrise at {sunrise_time}, Sunset at {sunset_time}." - ) - - return weather_report - - except Exception as e: - error_message = ( - f"An error occurred while fetching weather data for {city}: " - f"{e!s}." - ) - return error_message - - def get_tools(self) -> List[OpenAIFunction]: - r"""Returns a list of OpenAIFunction objects representing the + def get_tools(self) -> List[FunctionTool]: + r"""Returns a list of FunctionTool objects representing the functions in the toolkit. Returns: - List[OpenAIFunction]: A list of OpenAIFunction objects + List[FunctionTool]: A list of FunctionTool objects representing the functions in the toolkit. """ - return [ - OpenAIFunction(self.get_weather_data), - ] + return WEATHER_FUNCS -WEATHER_FUNCS: List[OpenAIFunction] = WeatherToolkit().get_tools() +WEATHER_FUNCS = [ + FunctionTool(func=get_weather_data, name_prefix=WeatherToolkit.__name__) +] diff --git a/camel/utils/async_func.py b/camel/utils/async_func.py index 377bca4f5e..77caf4e9ec 100644 --- a/camel/utils/async_func.py +++ b/camel/utils/async_func.py @@ -14,20 +14,20 @@ import asyncio from copy import deepcopy -from camel.toolkits import OpenAIFunction +from camel.toolkits import FunctionTool -def sync_funcs_to_async(funcs: list[OpenAIFunction]) -> list[OpenAIFunction]: +def sync_funcs_to_async(funcs: list[FunctionTool]) -> list[FunctionTool]: r"""Convert a list of Python synchronous functions to Python asynchronous functions. Args: - funcs (list[OpenAIFunction]): List of Python synchronous - functions in the :obj:`OpenAIFunction` format. + funcs (list[FunctionTool]): List of Python synchronous + functions in the :obj:`FunctionTool` format. Returns: - list[OpenAIFunction]: List of Python asynchronous functions - in the :obj:`OpenAIFunction` format. + list[FunctionTool]: List of Python asynchronous functions + in the :obj:`FunctionTool` format. """ async_funcs = [] for func in funcs: @@ -37,6 +37,6 @@ def async_callable(*args, **kwargs): return asyncio.to_thread(sync_func, *args, **kwargs) # noqa: B023 async_funcs.append( - OpenAIFunction(async_callable, deepcopy(func.openai_tool_schema)) + FunctionTool(async_callable, deepcopy(func.openai_tool_schema)) ) return async_funcs diff --git a/camel/utils/commons.py b/camel/utils/commons.py index 943e398c7f..8cc64d8b35 100644 --- a/camel/utils/commons.py +++ b/camel/utils/commons.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +import functools import importlib import os import platform @@ -577,3 +578,13 @@ def handle_http_error(response: requests.Response) -> str: return "Too Many Requests. You have hit the rate limit." else: return "HTTP Error" + + +def export_to_toolkit(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + return result + + wrapper._is_exported = True + return wrapper diff --git a/docs/camel.toolkits.rst b/docs/camel.toolkits.rst index c5647e8fd6..d8c1347309 100644 --- a/docs/camel.toolkits.rst +++ b/docs/camel.toolkits.rst @@ -71,7 +71,7 @@ camel.toolkits.open\_api\_toolkit module camel.toolkits.openai\_function module -------------------------------------- -.. automodule:: camel.toolkits.openai_function +.. automodule:: camel.toolkits.function_tool :members: :undoc-members: :show-inheritance: diff --git a/docs/key_modules/tools.md b/docs/key_modules/tools.md index d1b14e0ceb..4018b7f9c3 100644 --- a/docs/key_modules/tools.md +++ b/docs/key_modules/tools.md @@ -14,7 +14,7 @@ To enhance your agents' capabilities with CAMEL tools, start by installing our a pip install 'camel-ai[tools]' ``` -In CAMEL, a tool is an `OpenAIFunction` that LLMs can call. +In CAMEL, a tool is an `FunctionTool` that LLMs can call. ### 2.1 How to Define Your Own Tool? @@ -22,7 +22,7 @@ In CAMEL, a tool is an `OpenAIFunction` that LLMs can call. Developers can create custom tools tailored to their agent’s specific needs: ```python -from camel.toolkits import OpenAIFunction +from camel.toolkits import FunctionTool def add(a: int, b: int) -> int: r"""Adds two numbers. @@ -36,7 +36,7 @@ def add(a: int, b: int) -> int: """ return a + b -add_tool = OpenAIFunction(add) +add_tool = FunctionTool(add) ``` ```python @@ -105,8 +105,8 @@ To utilize specific tools from the toolkits, you can implement code like the fol ```python from camel.toolkits import SearchToolkit -google_tool = OpenAIFunction(SearchToolkit().search_google) -wiki_tool = OpenAIFunction(SearchToolkit().search_wiki) +google_tool = FunctionTool(SearchToolkit().search_google) +wiki_tool = FunctionTool(SearchToolkit().search_wiki) ``` Here is a list of the available CAMEL tools and their descriptions: diff --git a/examples/function_call/github_examples.py b/examples/function_call/github_examples.py index 48ac5b5aa8..833b0c712c 100644 --- a/examples/function_call/github_examples.py +++ b/examples/function_call/github_examples.py @@ -19,7 +19,7 @@ from camel.configs import ChatGPTConfig from camel.messages import BaseMessage from camel.models import ModelFactory -from camel.toolkits import GithubToolkit, OpenAIFunction +from camel.toolkits import FunctionTool, GithubToolkit from camel.types import ModelPlatformType, ModelType from camel.utils import print_text_animated @@ -70,7 +70,7 @@ def write_weekly_pr_summary(repo_name, model=None): agent = ChatAgent( assistant_sys_msg, model=assistant_model, - tools=[OpenAIFunction(toolkit.retrieve_pull_requests)], + tools=[FunctionTool(toolkit.retrieve_pull_requests)], ) agent.reset() diff --git a/examples/tasks/conditional_function_calling.py b/examples/tasks/conditional_function_calling.py new file mode 100644 index 0000000000..775a95b8a3 --- /dev/null +++ b/examples/tasks/conditional_function_calling.py @@ -0,0 +1,84 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +from camel.tasks.task import ( + FunctionToolState, + FunctionToolTransition, + Task, + TaskManagerWithState, +) +from camel.toolkits.toolkits_manager import ToolkitManager + +if __name__ == "__main__": + # Define subtasks with descriptive content and unique IDs + tasks = [ + Task(content="Search for suitable phone", id="1"), + Task(content="Place phone order", id="2"), + Task(content="Make payment", id="3"), + ] + + # Define task states with specific tools from the ToolkitManager + states = [ + FunctionToolState( + name="SearchPhone", + tools_space=ToolkitManager().search_toolkits('search'), + ), + FunctionToolState( + name="PlaceOrder", + tools_space=ToolkitManager().search_toolkits('math'), + ), + FunctionToolState( + name="MakePayment", + tools_space=ToolkitManager().search_toolkits('img'), + ), + FunctionToolState(name="Done"), + ] + + # Define task state transitions with trigger, source, and destination + transitions = [ + FunctionToolTransition( + trigger=tasks[0], source=states[0], dest=states[1] + ), + FunctionToolTransition( + trigger=tasks[1], source=states[1], dest=states[2] + ), + FunctionToolTransition( + trigger=tasks[2], source=states[2], dest=states[3] + ), + ] + + # Initialize the Task Manager, starting with the initial state of + # "SearchPhone" + task_manager = TaskManagerWithState( + task=tasks[0], + initial_state=states[0], + states=states, + transitions=transitions, + ) + + # Task execution loop until reaching the "Done" state + while task_manager.current_state != states[-1]: + # Print the current state and available tools + print(f"Current State: {task_manager.current_state}") + print(f"Current Tools: {task_manager.get_current_tools()}") + + # Retrieve and execute the current task if available + current_task = task_manager.current_task + if current_task: + print(f"Executing Task: {current_task.content}") + # Simulate task execution and update task result + current_task.update_result('Subtask completed!') + print(f"Task Result: {current_task.result}") + + # Print updated state after task completion + print(f"Updated State: {task_manager.current_state}") diff --git a/examples/toolkits/toolkts_manager_example.py b/examples/toolkits/toolkts_manager_example.py new file mode 100644 index 0000000000..d6ffe0d8cc --- /dev/null +++ b/examples/toolkits/toolkts_manager_example.py @@ -0,0 +1,232 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +from typing import Type, cast + +from camel.toolkits import ToolkitManager +from camel.toolkits.github_toolkit import GithubToolkit + + +def pretty_print_list(title, items): + print(f"\n{'=' * 40}\n{title}:\n{'-' * 40}") + if not items: + print(" (No project)") + else: + for index, item in enumerate(items, start=1): + print(f" {index}. {item}") + print('=' * 40) + + +manager = ToolkitManager() + +toolkits = manager.list_toolkits() +toolkit_classes = manager.list_toolkit_classes() + +pretty_print_list("Function Toolkits", toolkits) +pretty_print_list("Class Toolkits", toolkit_classes) +""" +=============================================================================== +======================================== +Function Toolkits: +---------------------------------------- + 1. DalleToolkit.get_dalle_img + 2. MathToolkit.add + 3. MathToolkit.mul + 4. MathToolkit.sub + 5. SearchToolkit.query_wolfram_alpha + 6. SearchToolkit.search_duckduckgo + 7. SearchToolkit.search_google + 8. SearchToolkit.search_wiki + 9. WeatherToolkit.get_weather_data +======================================== + +======================================== +Class Toolkits: +---------------------------------------- + 1. CodeExecutionToolkit: execute_code + 2. DalleToolkit: get_dalle_img + 3. GithubToolkit: create_pull_request, retrieve_issue, retrieve_issue_list, + retrieve_pull_requests + 4. GoogleMapsToolkit: get_address_description, get_elevation, get_timezone + 5. LinkedInToolkit: create_post, delete_post, get_profile + 6. MathToolkit: add, mul, sub + 7. RedditToolkit: collect_top_posts, perform_sentiment_analysis, + track_keyword_discussions + 8. RetrievalToolkit: information_retrieval + 9. SearchToolkit: query_wolfram_alpha, search_duckduckgo, search_google, + search_wiki + 10. SlackToolkit: create_slack_channel, delete_slack_message, + get_slack_channel_information, get_slack_channel_message, + join_slack_channel, leave_slack_channel, send_slack_message + 11. TwitterToolkit: create_tweet, delete_tweet, get_my_user_profile + 12. WeatherToolkit: get_weather_data +======================================== +=============================================================================== +""" + +matching_toolkits_test = manager.search_toolkits('weather') +pretty_print_list("Matching Toolkit", matching_toolkits_test) + + +def strict_search_algorithm(keyword: str, description: str) -> bool: + return keyword.lower() in description.lower() + + +matching_toolkits_custom = manager.search_toolkits( + 'weather', algorithm=strict_search_algorithm +) +pretty_print_list( + "Custom Algorithm Matching Toolkit", matching_toolkits_custom +) +""" +=============================================================================== +======================================== +Matching Toolkit: +---------------------------------------- + 1. WeatherToolkit.get_weather_data +======================================== + +======================================== +Custom Algorithm Matching Toolkit: +---------------------------------------- + 1. WeatherToolkit.get_weather_data +======================================== +=============================================================================== +""" + +tool = manager.get_toolkit('WeatherToolkit.get_weather_data') +if tool: + print("\nFunction Description:") + print('-' * 40) + print(tool.get_function_description()) +""" +=============================================================================== +Function Description: +---------------------------------------- +Fetch and return a comprehensive weather report for a given city +as a string. The report includes current weather conditions, +temperature, wind details, visibility, and sunrise/sunset times, +all formatted as a readable string. + +The function interacts with the OpenWeatherMap API to +retrieve the data. + +=============================================================================== +""" + + +def div(a: int, b: int) -> float: + r"""Divides two numbers. + + Args: + a (int): The dividend in the division. + b (int): The divisor in the division. + + Returns: + float: The quotient of the division. + + Raises: + ValueError: If the divisor is zero. + """ + if b == 0: + raise ValueError("Division by zero is not allowed.") + + return a / b + + +camel_github_toolkit = GithubToolkit(repo_name='camel-ai/camel') + +added_tools = manager.register_tool( + [div, camel_github_toolkit] +) # manager.register_tool(div) is also supported. + +pretty_print_list("Added Tools", added_tools) +pretty_print_list("Available Toolkits for now", manager.list_toolkits()) +""" +=============================================================================== +======================================== +Added Tools: +---------------------------------------- + 1. div + 2. GithubToolkit.retrieve_issue_list + 3. GithubToolkit.retrieve_issue + 4. GithubToolkit.create_pull_request + 5. GithubToolkit.retrieve_pull_requests +======================================== +======================================== +Available Toolkits for now: +---------------------------------------- + 1. DalleToolkit.get_dalle_img + 2. MathToolkit.add + 3. MathToolkit.mul + 4. MathToolkit.sub + 5. SearchToolkit.query_wolfram_alpha + 6. SearchToolkit.search_duckduckgo + 7. SearchToolkit.search_google + 8. SearchToolkit.search_wiki + 9. WeatherToolkit.get_weather_data + 10. div + 11. GithubToolkit.create_pull_request + 12. GithubToolkit.retrieve_issue + 13. GithubToolkit.retrieve_issue_list + 14. GithubToolkit.retrieve_pull_requests +======================================== +=============================================================================== +""" + +crab_github_toolkit = GithubToolkit(repo_name='ZackYule/crab') + +# Custom instance names are supported here. +manager.add_toolkit_from_instance( + crab_github_toolkit=crab_github_toolkit, +) + +matching_tools_for_github = manager.search_toolkits('github') +pretty_print_list("Matching Tools for GitHub", matching_tools_for_github) +""" +=============================================================================== +======================================== +Matching Tools for GitHub: +---------------------------------------- + 1. GithubToolkit.create_pull_request + 2. GithubToolkit.retrieve_issue + 3. GithubToolkit.retrieve_issue_list + 4. GithubToolkit.retrieve_pull_requests + 5. crab_github_toolkit.create_pull_request + 6. crab_github_toolkit.retrieve_issue + 7. crab_github_toolkit.retrieve_issue_list + 8. crab_github_toolkit.retrieve_pull_requests +======================================== +=============================================================================== +""" + +toolkit_class = manager.get_toolkit_class('GithubToolkit') + +if toolkit_class: + toolkit_class = cast(Type[GithubToolkit], toolkit_class) + instance = toolkit_class(repo_name='ZackYule/crab') + pretty_print_list( + "Tools in the crab GitHub Toolkit instance", instance.get_tools() + ) +""" +=============================================================================== +======================================== +Tools in the crab GitHub Toolkit instance: +---------------------------------------- + 1. GithubToolkit.create_pull_request + 2. GithubToolkit.retrieve_issue + 3. GithubToolkit.retrieve_issue_list + 4. GithubToolkit.retrieve_pull_requests +======================================== +=============================================================================== +""" diff --git a/examples/workforce/hackathon_judges.py b/examples/workforce/hackathon_judges.py index ed993b2197..8aeb422164 100644 --- a/examples/workforce/hackathon_judges.py +++ b/examples/workforce/hackathon_judges.py @@ -18,7 +18,7 @@ from camel.messages import BaseMessage from camel.models import ModelFactory from camel.tasks import Task -from camel.toolkits import OpenAIFunction, SearchToolkit +from camel.toolkits import FunctionTool, SearchToolkit from camel.types import ModelPlatformType, ModelType from camel.workforce import Workforce @@ -76,8 +76,8 @@ def main(): search_toolkit = SearchToolkit() search_tools = [ - OpenAIFunction(search_toolkit.search_google), - OpenAIFunction(search_toolkit.search_duckduckgo), + FunctionTool(search_toolkit.search_google), + FunctionTool(search_toolkit.search_duckduckgo), ] researcher_model = ModelFactory.create( diff --git a/examples/workforce/multiple_single_agents.py b/examples/workforce/multiple_single_agents.py index 878bc64418..d18b504818 100644 --- a/examples/workforce/multiple_single_agents.py +++ b/examples/workforce/multiple_single_agents.py @@ -19,8 +19,8 @@ from camel.tasks.task import Task from camel.toolkits import ( WEATHER_FUNCS, + FunctionTool, GoogleMapsToolkit, - OpenAIFunction, SearchToolkit, ) from camel.types import ModelPlatformType, ModelType @@ -30,8 +30,8 @@ def main(): search_toolkit = SearchToolkit() search_tools = [ - OpenAIFunction(search_toolkit.search_google), - OpenAIFunction(search_toolkit.search_duckduckgo), + FunctionTool(search_toolkit.search_google), + FunctionTool(search_toolkit.search_duckduckgo), ] # Set up web searching agent diff --git a/test/agents/test_chat_agent.py b/test/agents/test_chat_agent.py index 01f6860cc8..3ca51fbc7b 100644 --- a/test/agents/test_chat_agent.py +++ b/test/agents/test_chat_agent.py @@ -33,8 +33,8 @@ from camel.models import ModelFactory from camel.terminators import ResponseWordsTerminator from camel.toolkits import ( + FunctionTool, MathToolkit, - OpenAIFunction, SearchToolkit, ) from camel.types import ( @@ -69,29 +69,37 @@ def test_chat_agent(model): dict(assistant_role="doctor"), role_tuple=("doctor", RoleType.ASSISTANT), ) - assistant = ChatAgent(system_msg, model=model) + assistant_with_sys_msg = ChatAgent(system_msg, model=model) + assistant_without_sys_msg = ChatAgent(model=model) - assert str(assistant) == ( + assert str(assistant_with_sys_msg) == ( "ChatAgent(doctor, " f"RoleType.ASSISTANT, {ModelType.GPT_4O_MINI})" ) + assert str(assistant_without_sys_msg) == ( + "ChatAgent(assistant, " f"RoleType.ASSISTANT, {ModelType.GPT_4O_MINI})" + ) + + for assistant in [assistant_with_sys_msg, assistant_without_sys_msg]: + assistant.reset() - assistant.reset() user_msg = BaseMessage( role_name="Patient", role_type=RoleType.USER, meta_dict=dict(), content="Hello!", ) - assistant_response = assistant.step(user_msg) - assert isinstance(assistant_response.msgs, list) - assert len(assistant_response.msgs) > 0 - assert isinstance(assistant_response.terminated, bool) - assert assistant_response.terminated is False - assert isinstance(assistant_response.info, dict) - assert assistant_response.info['id'] is not None + for assistant in [assistant_with_sys_msg, assistant_without_sys_msg]: + response = assistant.step(user_msg) + assert isinstance(response.msgs, list) + assert len(response.msgs) > 0 + assert isinstance(response.terminated, bool) + assert response.terminated is False + assert isinstance(response.info, dict) + assert response.info['id'] is not None +@pytest.mark.model_backend def test_chat_agent_stored_messages(): system_msg = BaseMessage( role_name="assistant", @@ -99,11 +107,16 @@ def test_chat_agent_stored_messages(): meta_dict=None, content="You are a help assistant.", ) - assistant = ChatAgent(system_msg) + + assistant_with_sys_msg = ChatAgent(system_msg) + assistant_without_sys_msg = ChatAgent() expected_context = [system_msg.to_openai_system_message()] - context, _ = assistant.memory.get_context() - assert context == expected_context + + context_with_sys_msg, _ = assistant_with_sys_msg.memory.get_context() + assert context_with_sys_msg == expected_context + context_without_sys_msg, _ = assistant_without_sys_msg.memory.get_context() + assert context_without_sys_msg == [] user_msg = BaseMessage( role_name="User", @@ -111,13 +124,22 @@ def test_chat_agent_stored_messages(): meta_dict=dict(), content="Tell me a joke.", ) - assistant.update_memory(user_msg, OpenAIBackendRole.USER) - expected_context = [ + + for assistant in [assistant_with_sys_msg, assistant_without_sys_msg]: + assistant.update_memory(user_msg, OpenAIBackendRole.USER) + + expected_context_with_sys_msg = [ system_msg.to_openai_system_message(), user_msg.to_openai_user_message(), ] - context, _ = assistant.memory.get_context() - assert context == expected_context + expected_context_without_sys_msg = [ + user_msg.to_openai_user_message(), + ] + + context_with_sys_msg, _ = assistant_with_sys_msg.memory.get_context() + assert context_with_sys_msg == expected_context_with_sys_msg + context_without_sys_msg, _ = assistant_without_sys_msg.memory.get_context() + assert context_without_sys_msg == expected_context_without_sys_msg @pytest.mark.model_backend @@ -273,17 +295,27 @@ def test_chat_agent_multiple_return_messages(n): meta_dict=None, content="You are a helpful assistant.", ) - assistant = ChatAgent(system_msg, model=model) - assistant.reset() + assistant_with_sys_msg = ChatAgent(system_msg, model=model) + assistant_without_sys_msg = ChatAgent(model=model) + + assistant_with_sys_msg.reset() + assistant_without_sys_msg.reset() + user_msg = BaseMessage( role_name="User", role_type=RoleType.USER, meta_dict=dict(), content="Tell me a joke.", ) - assistant_response = assistant.step(user_msg) - assert assistant_response.msgs is not None - assert len(assistant_response.msgs) == n + assistant_with_sys_msg_response = assistant_with_sys_msg.step(user_msg) + assistant_without_sys_msg_response = assistant_without_sys_msg.step( + user_msg + ) + + assert assistant_with_sys_msg_response.msgs is not None + assert len(assistant_with_sys_msg_response.msgs) == n + assert assistant_without_sys_msg_response.msgs is not None + assert len(assistant_without_sys_msg_response.msgs) == n @pytest.mark.model_backend @@ -396,21 +428,41 @@ def test_set_multiple_output_language(): meta_dict=None, content="You are a help assistant.", ) - agent = ChatAgent(system_message=system_message) + agent_with_sys_msg = ChatAgent(system_message=system_message) + agent_without_sys_msg = ChatAgent() # Verify that the length of the system message is kept constant even when # multiple set_output_language operations are called - agent.set_output_language("Chinese") - agent.set_output_language("English") - agent.set_output_language("French") - updated_system_message = BaseMessage( + agent_with_sys_msg.set_output_language("Chinese") + agent_with_sys_msg.set_output_language("English") + agent_with_sys_msg.set_output_language("French") + agent_without_sys_msg.set_output_language("Chinese") + agent_without_sys_msg.set_output_language("English") + agent_without_sys_msg.set_output_language("French") + + updated_system_message_with_content = BaseMessage( role_name="assistant", role_type=RoleType.ASSISTANT, meta_dict=None, content="You are a help assistant." "\nRegardless of the input language, you must output text in French.", ) - assert agent.system_message.content == updated_system_message.content + updated_system_message_without_content = BaseMessage( + role_name="assistant", + role_type=RoleType.ASSISTANT, + meta_dict=None, + content="\nRegardless of the input language, you must output text " + "in French.", + ) + + assert ( + agent_with_sys_msg.system_message.content + == updated_system_message_with_content.content + ) + assert ( + agent_without_sys_msg.system_message.content + == updated_system_message_without_content.content + ) @pytest.mark.model_backend @@ -554,7 +606,7 @@ async def async_sleep(second: int) -> int: agent = ChatAgent( system_message=system_message, model=model, - tools=[OpenAIFunction(async_sleep)], + tools=[FunctionTool(async_sleep)], ) assert len(agent.func_dict) == 1 diff --git a/test/toolkits/test_openai_function.py b/test/toolkits/test_openai_function.py index 3df1b32a72..256345001a 100644 --- a/test/toolkits/test_openai_function.py +++ b/test/toolkits/test_openai_function.py @@ -19,7 +19,7 @@ import pytest from jsonschema.exceptions import SchemaError -from camel.toolkits import OpenAIFunction, get_openai_tool_schema +from camel.toolkits import FunctionTool, get_openai_tool_schema from camel.types import RoleType from camel.utils import get_pydantic_major_version @@ -336,13 +336,13 @@ def add_with_wrong_doc(a: int, b: int) -> int: def test_correct_function(): - add = OpenAIFunction(add_with_doc) + add = FunctionTool(add_with_doc) add.set_function_name("add") assert add.get_openai_function_schema() == function_schema def test_function_without_doc(): - add = OpenAIFunction(add_without_doc) + add = FunctionTool(add_without_doc) add.set_function_name("add") with pytest.raises(Exception, match="miss function description"): _ = add.get_openai_function_schema() @@ -351,7 +351,7 @@ def test_function_without_doc(): def test_function_with_wrong_doc(): - add = OpenAIFunction(add_with_wrong_doc) + add = FunctionTool(add_with_wrong_doc) add.set_function_name("add") with pytest.raises(Exception, match="miss description of parameter \"b\""): _ = add.get_openai_function_schema() @@ -360,11 +360,11 @@ def test_function_with_wrong_doc(): def test_validate_openai_tool_schema_valid(): - OpenAIFunction.validate_openai_tool_schema(tool_schema) + FunctionTool.validate_openai_tool_schema(tool_schema) def test_get_set_openai_tool_schema(): - add = OpenAIFunction(add_with_doc) + add = FunctionTool(add_with_doc) assert add.get_openai_tool_schema() is not None new_schema = copy.deepcopy(tool_schema) new_schema["function"]["description"] = "New description" @@ -373,20 +373,20 @@ def test_get_set_openai_tool_schema(): def test_get_set_parameter_description(): - add = OpenAIFunction(add_with_doc) + add = FunctionTool(add_with_doc) assert add.get_paramter_description("a") == "The first number to be added." add.set_paramter_description("a", "New description for a.") assert add.get_paramter_description("a") == "New description for a." def test_get_set_parameter_description_non_existing(): - add = OpenAIFunction(add_with_doc) + add = FunctionTool(add_with_doc) with pytest.raises(KeyError): add.get_paramter_description("non_existing") def test_get_set_openai_function_schema(): - add = OpenAIFunction(add_with_doc) + add = FunctionTool(add_with_doc) initial_schema = add.get_openai_function_schema() assert initial_schema is not None @@ -400,7 +400,7 @@ def test_get_set_openai_function_schema(): def test_get_set_function_name(): - add = OpenAIFunction(add_with_doc) + add = FunctionTool(add_with_doc) assert add.get_function_name() == "add_with_doc" add.set_function_name("new_add") @@ -408,7 +408,7 @@ def test_get_set_function_name(): def test_get_set_function_description(): - add = OpenAIFunction(add_with_doc) + add = FunctionTool(add_with_doc) initial_description = add.get_function_description() assert initial_description is not None @@ -418,7 +418,7 @@ def test_get_set_function_description(): def test_get_set_parameter(): - add = OpenAIFunction(add_with_doc) + add = FunctionTool(add_with_doc) initial_param_schema = add.get_parameter("a") assert initial_param_schema is not None @@ -431,7 +431,7 @@ def test_get_set_parameter(): def test_parameters_getter_setter(): - add = OpenAIFunction(add_with_doc) + add = FunctionTool(add_with_doc) initial_params = add.parameters assert initial_params is not None diff --git a/test/toolkits/test_reddit_functions.py b/test/toolkits/test_reddit_functions.py index 149b749b64..64c1370a16 100644 --- a/test/toolkits/test_reddit_functions.py +++ b/test/toolkits/test_reddit_functions.py @@ -120,8 +120,8 @@ def test_track_keyword_discussions(reddit_toolkit): def test_get_tools(reddit_toolkit): - from camel.toolkits import OpenAIFunction + from camel.toolkits import FunctionTool tools = reddit_toolkit.get_tools() assert len(tools) == 3 - assert all(isinstance(tool, OpenAIFunction) for tool in tools) + assert all(isinstance(tool, FunctionTool) for tool in tools) diff --git a/test/toolkits/test_search_functions.py b/test/toolkits/test_search_functions.py index 71fb81b642..7f49b79de6 100644 --- a/test/toolkits/test_search_functions.py +++ b/test/toolkits/test_search_functions.py @@ -90,7 +90,7 @@ def test_google_api(): assert result.status_code == 200 -search_duckduckgo = SearchToolkit().search_duckduckgo +search_duckduckgo = SearchToolkit().get_tools()[1] def test_search_duckduckgo_text():