-
Notifications
You must be signed in to change notification settings - Fork 131
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enhance batch job task management by adding default action types #3080
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
package org.opensearch.ml.engine.algorithms.remote; | ||
|
||
import static org.apache.commons.text.StringEscapeUtils.escapeJson; | ||
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT; | ||
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT; | ||
import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD; | ||
import static org.opensearch.ml.common.connector.MLPreProcessFunction.CONVERT_INPUT_TO_JSON_STRING; | ||
|
@@ -19,6 +20,7 @@ | |
import java.net.URI; | ||
import java.nio.charset.Charset; | ||
import java.util.ArrayList; | ||
import java.util.Collections; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
@@ -61,6 +63,9 @@ public class ConnectorUtils { | |
private static final Aws4Signer signer; | ||
public static final String SKIP_VALIDATE_MISSING_PARAMETERS = "skip_validating_missing_parameters"; | ||
|
||
public static final List<String> SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES = List | ||
.of("sagemaker", "openai", "bedrock", "cohere"); | ||
|
||
Comment on lines
+66
to
+68
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ideally if a new platform is used but not listed here, CX should still be able to GetTask and CancelTask by manually adding the actions in the connector. But seems this is not the case in this PR? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes if the connector already has the actions configured, then they can get/cancel task for any platform. Only if no action is provided, they we perform this check |
||
static { | ||
signer = Aws4Signer.create(); | ||
} | ||
|
@@ -313,4 +318,58 @@ public static SdkHttpFullRequest buildSdkRequest( | |
} | ||
return builder.build(); | ||
} | ||
|
||
public static ConnectorAction createConnectorAction(Connector connector, ConnectorAction.ActionType actionType) { | ||
Optional<ConnectorAction> batchPredictAction = connector.findAction(BATCH_PREDICT.name()); | ||
String predictEndpoint = batchPredictAction.get().getUrl(); | ||
Map<String, String> parameters = connector.getParameters() != null | ||
? new HashMap<>(connector.getParameters()) | ||
: Collections.emptyMap(); | ||
|
||
// Apply parameter substitution only if needed | ||
if (!parameters.isEmpty()) { | ||
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); | ||
predictEndpoint = substitutor.replace(predictEndpoint); | ||
} | ||
|
||
boolean isCancelAction = actionType == CANCEL_BATCH_PREDICT; | ||
|
||
// Initialize the default method and requestBody | ||
String method = "POST"; | ||
String requestBody = null; | ||
String url = ""; | ||
|
||
switch (getRemoteServerFromURL(predictEndpoint)) { | ||
case "sagemaker": | ||
url = isCancelAction | ||
? predictEndpoint.replace("CreateTransformJob", "StopTransformJob") | ||
: predictEndpoint.replace("CreateTransformJob", "DescribeTransformJob"); | ||
requestBody = "{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}"; | ||
break; | ||
case "openai": | ||
case "cohere": | ||
url = isCancelAction ? predictEndpoint + "/${parameters.id}/cancel" : predictEndpoint + "/${parameters.id}"; | ||
method = isCancelAction ? "POST" : "GET"; | ||
break; | ||
case "bedrock": | ||
url = isCancelAction | ||
? predictEndpoint + "/${parameters.processedJobArn}/stop" | ||
: predictEndpoint + "/${parameters.processedJobArn}"; | ||
method = isCancelAction ? "POST" : "GET"; | ||
break; | ||
Comment on lines
+343
to
+359
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should add the default branch for this switch statement to return null. In the GetTask and CancelTask, if the ConnectorAction is null, throw an exception with meaning logs like "please provide GetTask/CancelTask action in the connector". There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh sorry I added it but I guess it got missed during refactoring. Let me add it |
||
} | ||
|
||
return ConnectorAction | ||
.builder() | ||
.actionType(actionType) | ||
.method(method) | ||
.url(url) | ||
.requestBody(requestBody) | ||
.headers(batchPredictAction.get().getHeaders()) | ||
.build(); | ||
} | ||
|
||
public static String getRemoteServerFromURL(String url) { | ||
return SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES.stream().filter(url::contains).findFirst().orElse(""); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's add a comment about this field.