diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java index fc2b9451c2..dfdbd673f5 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java @@ -167,7 +167,9 @@ private Tool createTool(MLToolSpec toolSpec) { if (!toolFactories.containsKey(toolSpec.getType())) { throw new IllegalArgumentException("Tool not found: " + toolSpec.getType()); } - Tool tool = toolFactories.get(toolSpec.getType()).create(toolParams); + Tool.Factory factory = toolFactories.get(toolSpec.getType()); + factory.initClient(client); + Tool tool = factory.create(toolParams); if (toolSpec.getName() != null) { tool.setName(toolSpec.getName()); } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 87126c3619..def2c6fe5d 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -10,16 +10,24 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; +import java.io.File; +import java.net.URL; +import java.net.URLClassLoader; import java.nio.file.Path; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.ServiceLoader; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Supplier; +import org.apache.lucene.spatial3d.geom.Tools; import org.opensearch.action.ActionRequest; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; @@ -859,6 +867,12 @@ public Map> getResponseProces @Override public void loadExtensions(ExtensionLoader loader) { externalToolFactories = new HashMap<>(); + ServiceLoader serviceLoader = ServiceLoader.load(Tool.class, Tool.class.getClassLoader()); + for (Tool tool : serviceLoader) { + Tool.Factory factory = tool.getFactory(); + externalToolFactories.put(tool.getType(), factory); + } + for (MLCommonsExtension extension : loader.loadExtensions(MLCommonsExtension.class)) { List> toolFactories = extension.getToolFactories(); for (Tool.Factory toolFactory : toolFactories) { diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java b/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java index e87128fb7e..8c7b55dce8 100644 --- a/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java +++ b/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java @@ -5,6 +5,7 @@ package org.opensearch.ml.common.spi.tools; +import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import java.util.Map; @@ -107,6 +108,10 @@ default boolean useOriginalInput() { return false; } + default Factory getFactory() { + return null; + } + /** * Tool factory which can create instance of {@link Tool}. * @param The subclass this factory produces @@ -120,6 +125,10 @@ interface Factory { */ T create(Map params); + default void initClient(Client client) { + + } + /** * Get the default description of this tool. * @return the default description