diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index d6149feddc..7f691e5c0b 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -17,8 +17,10 @@ import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; +import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.utils.TestHelper; @@ -76,31 +78,55 @@ public void setup() throws IOException, InterruptedException { Thread.sleep(20000); } - public void testCreateConnector() throws IOException { - Response response = createConnector(completionModelConnectorEntity); - Map responseMap = parseResponseToMap(response); - assertNotNull((String) responseMap.get("connector_id")); - } - - public void testGetConnector() throws IOException { + public void testCreate_Get_DeleteConnector() throws IOException { Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); + assertNotNull(connectorId); // Testing create connector + + // Testing Get connector response = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/connectors/" + connectorId, null, "", null); responseMap = parseResponseToMap(response); - assertEquals("OpenAI Connector", (String) responseMap.get("name")); - assertEquals("1", (String) responseMap.get("version")); - assertEquals("The connector to public OpenAI model service for GPT 3.5", (String) responseMap.get("description")); - assertEquals("http", (String) responseMap.get("protocol")); - } + assertEquals("OpenAI Connector", responseMap.get("name")); + assertEquals("1", responseMap.get("version")); + assertEquals("The connector to public OpenAI model service for GPT 3.5", responseMap.get("description")); + assertEquals("http", responseMap.get("protocol")); - public void testDeleteConnector() throws IOException { - Response response = createConnector(completionModelConnectorEntity); - Map responseMap = parseResponseToMap(response); - String connectorId = (String) responseMap.get("connector_id"); + // Testing delete connector response = TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/connectors/" + connectorId, null, "", null); responseMap = parseResponseToMap(response); - assertEquals("deleted", (String) responseMap.get("result")); + assertEquals("deleted", responseMap.get("result")); + + } + + private static String maskSensitiveInfo(String input) { + // Regex to remove the whole credential object and replace it with "***" + String regex = "\"credential\":\\{.*?}"; + return input.replaceAll(regex, "\"credential\": \"***\""); + } + + @Test + public void testMaskSensitiveInfo_withCredential() { + String input = "{\"credential\":{\"username\":\"admin\",\"password\":\"secret\"}}"; + String expectedOutput = "{\"credential\": \"***\"}"; + String actualOutput = maskSensitiveInfo(input); + assertEquals(expectedOutput, actualOutput); + } + + @Test + public void testMaskSensitiveInfo_noCredential() { + String input = "{\"otherInfo\":\"someValue\"}"; + String expectedOutput = "{\"otherInfo\":\"someValue\"}"; + String actualOutput = maskSensitiveInfo(input); + assertEquals(expectedOutput, actualOutput); + } + + @Test + public void testMaskSensitiveInfo_emptyInput() { + String input = ""; + String expectedOutput = ""; + String actualOutput = maskSensitiveInfo(input); + assertEquals(expectedOutput, actualOutput); } public void testSearchConnectors_beforeCreation() throws IOException { @@ -108,7 +134,7 @@ public void testSearchConnectors_beforeCreation() throws IOException { Response response = TestHelper .makeRequest(client(), "GET", "/_plugins/_ml/connectors/_search", null, TestHelper.toHttpEntity(searchEntity), null); Map responseMap = parseResponseToMap(response); - assertEquals((Double) 0.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); + assertEquals(0.0, ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); } public void testSearchConnectors_afterCreation() throws IOException { @@ -125,7 +151,7 @@ public void testSearchRemoteModels_beforeCreation() throws IOException { Response response = TestHelper .makeRequest(client(), "GET", "/_plugins/_ml/models/_search", null, TestHelper.toHttpEntity(searchEntity), null); Map responseMap = parseResponseToMap(response); - assertEquals((Double) 0.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); + assertEquals(0.0, ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); } public void testSearchRemoteModels_afterCreation() throws IOException { @@ -134,7 +160,7 @@ public void testSearchRemoteModels_afterCreation() throws IOException { Response response = TestHelper .makeRequest(client(), "GET", "/_plugins/_ml/models/_search", null, TestHelper.toHttpEntity(searchEntity), null); Map responseMap = parseResponseToMap(response); - assertEquals((Double) 1.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); + assertEquals(1.0, ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); } public void testSearchModelGroups_beforeCreation() throws IOException { @@ -142,7 +168,7 @@ public void testSearchModelGroups_beforeCreation() throws IOException { Response response = TestHelper .makeRequest(client(), "GET", "/_plugins/_ml/model_groups/_search", null, TestHelper.toHttpEntity(searchEntity), null); Map responseMap = parseResponseToMap(response); - assertEquals((Double) 0.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); + assertEquals(0.0, ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); } public void testSearchModelGroups_afterCreation() throws IOException { @@ -151,7 +177,7 @@ public void testSearchModelGroups_afterCreation() throws IOException { Response response = TestHelper .makeRequest(client(), "GET", "/_plugins/_ml/model_groups/_search", null, TestHelper.toHttpEntity(searchEntity), null); Map responseMap = parseResponseToMap(response); - assertEquals((Double) 1.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); + assertEquals(1.0, ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); } public void testSearchMLTasks_beforeCreation() throws IOException { @@ -159,7 +185,7 @@ public void testSearchMLTasks_beforeCreation() throws IOException { Response response = TestHelper .makeRequest(client(), "GET", "/_plugins/_ml/tasks/_search", null, TestHelper.toHttpEntity(searchEntity), null); Map responseMap = parseResponseToMap(response); - assertEquals((Double) 0.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); + assertEquals(0.0, ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); } public void testSearchMLTasks_afterCreation() throws IOException { @@ -169,7 +195,7 @@ public void testSearchMLTasks_afterCreation() throws IOException { Response response = TestHelper .makeRequest(client(), "GET", "/_plugins/_ml/tasks/_search", null, TestHelper.toHttpEntity(searchEntity), null); Map responseMap = parseResponseToMap(response); - assertEquals((Double) 1.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); + assertEquals(1.0, ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); } public void testDeployRemoteModel() throws IOException, InterruptedException { @@ -185,7 +211,7 @@ public void testDeployRemoteModel() throws IOException, InterruptedException { String modelId = (String) responseMap.get("model_id"); response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); - assertEquals("COMPLETED", (String) responseMap.get("status")); + assertEquals("COMPLETED", responseMap.get("status")); taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); } @@ -838,7 +864,12 @@ public void testCohereClassifyModel() throws IOException, InterruptedException { } public static Response createConnector(String input) throws IOException { - return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, TestHelper.toHttpEntity(input), null); + try { + return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, TestHelper.toHttpEntity(input), null); + } catch (ResponseException e) { + String sanitizedMessage = maskSensitiveInfo(e.getMessage());// Log sanitized message + throw new RuntimeException("Request failed: " + sanitizedMessage); // Re-throw sanitized exception + } } public static Response registerRemoteModel(String name, String connectorId) throws IOException {