From 8c2d61ea71ac12b4c8adc1593dde6b77ba3da4d9 Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Fri, 26 Jul 2024 14:27:40 +0800 Subject: [PATCH] create conversation support additional info Signed-off-by: Hailong Cui --- .../memory/ConversationalMemoryHandler.java | 14 +++++++ .../CreateConversationRequest.java | 37 ++++++++++++++++--- .../CreateConversationTransportAction.java | 5 ++- .../memory/index/ConversationMetaIndex.java | 17 +++++++-- ...OpenSearchConversationalMemoryHandler.java | 18 ++++++++- .../CreateConversationRequestTests.java | 26 +++++++++++-- ...reateConversationTransportActionTests.java | 10 ++--- .../GetConversationResponseTests.java | 16 ++++---- .../index/ConversationMetaIndexITTests.java | 37 +++++++++++++++++++ ...earchConversationalMemoryHandlerTests.java | 15 +++++++- 10 files changed, 166 insertions(+), 29 deletions(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java index b553a222a9..aa24016b96 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java @@ -68,6 +68,20 @@ public interface ConversationalMemoryHandler { */ public void createConversation(String name, String applicationType, ActionListener listener); + /** + * Create a new conversation + * @param name the name of the new conversation + * @param applicationType the application that creates this conversation + * @param additionalInfos additional information associated with this conversation + * @param listener listener to wait for this op to finish, gets unique id of new conversation + */ + public void createConversation( + String name, + String applicationType, + Map additionalInfos, + ActionListener listener + ); + /** * Adds an interaction to the conversation indicated, updating the conversational metadata * @param conversationId the conversation to add the interaction to diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java index 950a2c5a88..a02c71b9d8 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java @@ -18,6 +18,7 @@ package org.opensearch.ml.memory.action.conversation; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD; import java.io.IOException; import java.util.Map; @@ -27,6 +28,7 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.conversation.ActionConstants; import org.opensearch.rest.RestRequest; @@ -40,6 +42,8 @@ public class CreateConversationRequest extends ActionRequest { private String name = null; @Getter private String applicationType = null; + @Getter + private Map additionalInfos = null; /** * Constructor @@ -50,6 +54,9 @@ public CreateConversationRequest(StreamInput in) throws IOException { super(in); this.name = in.readOptionalString(); this.applicationType = in.readOptionalString(); + if (in.readBoolean()) { + this.additionalInfos = in.readMap(StreamInput::readString, StreamInput::readString); + } } /** @@ -71,6 +78,19 @@ public CreateConversationRequest(String name, String applicationType) { this.applicationType = applicationType; } + /** + * Constructor + * @param name name of the conversation + * @param applicationType of the conversation + * @param additionalInfos information of the conversation + */ + public CreateConversationRequest(String name, String applicationType, Map additionalInfos) { + super(); + this.name = name; + this.applicationType = applicationType; + this.additionalInfos = additionalInfos; + } + /** * Constructor * name will be null @@ -82,6 +102,12 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeOptionalString(name); out.writeOptionalString(applicationType); + if (additionalInfos == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeMap(additionalInfos, StreamOutput::writeString, StreamOutput::writeString); + } } @Override @@ -101,12 +127,13 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest) if (!restRequest.hasContent()) { return new CreateConversationRequest(); } - try { - Map body = restRequest.contentParser().mapStrings(); - if (body.containsKey(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD)) { + try (XContentParser parser = restRequest.contentParser()) { + Map body = parser.map(); + if (body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) != null) { return new CreateConversationRequest( - body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD), - body.get(APPLICATION_TYPE_FIELD) + (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD), + body.get(APPLICATION_TYPE_FIELD) == null ? null : (String) body.get(APPLICATION_TYPE_FIELD), + body.get(META_ADDITIONAL_INFO_FIELD) == null ? null : (Map) body.get(META_ADDITIONAL_INFO_FIELD) ); } else { return new CreateConversationRequest(); diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java index 0a882b00dd..11bc569d00 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java @@ -19,6 +19,8 @@ import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE; +import java.util.Map; + import org.opensearch.OpenSearchException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -79,6 +81,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis } String name = request.getName(); String applicationType = request.getApplicationType(); + Map additionalInfos = request.getAdditionalInfos(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); ActionListener al = ActionListener.wrap(r -> { internalListener.onResponse(new CreateConversationResponse(r)); }, e -> { @@ -89,7 +92,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis if (name == null) { cmHandler.createConversation(al); } else { - cmHandler.createConversation(name, applicationType, al); + cmHandler.createConversation(name, applicationType, additionalInfos, al); } } catch (Exception e) { log.error("Failed to create new memory with name " + request.getName(), e); diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index 21fba2bf9d..06f1a34746 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -24,6 +24,7 @@ import java.time.Instant; import java.util.LinkedList; import java.util.List; +import java.util.Map; import org.opensearch.OpenSearchStatusException; import org.opensearch.OpenSearchWrapperException; @@ -126,9 +127,15 @@ public void initConversationMetaIndexIfAbsent(ActionListener listener) * Adds a new conversation with the specified name to the index * @param name user-specified name of the conversation to be added * @param applicationType the application type that creates this conversation + * @param additionalInfos the additional info that creates this conversation * @param listener listener to wait for this to finish */ - public void createConversation(String name, String applicationType, ActionListener listener) { + public void createConversation( + String name, + String applicationType, + Map additionalInfos, + ActionListener listener + ) { initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> { if (indexExists) { String userstr = getUserStrFromThreadContext(); @@ -145,7 +152,9 @@ public void createConversation(String name, String applicationType, ActionListen ConversationalIndexConstants.USER_FIELD, userstr == null ? null : User.parse(userstr).getName(), ConversationalIndexConstants.APPLICATION_TYPE_FIELD, - applicationType + applicationType, + ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, + additionalInfos == null ? Map.of() : additionalInfos ); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); @@ -176,7 +185,7 @@ public void createConversation(String name, String applicationType, ActionListen * @param listener listener to wait for this to finish */ public void createConversation(ActionListener listener) { - createConversation("", "", listener); + createConversation("", "", null, listener); } /** @@ -185,7 +194,7 @@ public void createConversation(ActionListener listener) { * @param listener listener to wait for this to finish */ public void createConversation(String name, ActionListener listener) { - createConversation(name, "", listener); + createConversation(name, "", null, listener); } /** diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java index 89a128e0f3..755d207a86 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java @@ -103,7 +103,23 @@ public void createConversation(String name, ActionListener listener) { * @param listener listener to wait for this op to finish, gets unique id of new conversation */ public void createConversation(String name, String applicationType, ActionListener listener) { - conversationMetaIndex.createConversation(name, applicationType, listener); + conversationMetaIndex.createConversation(name, applicationType, null, listener); + } + + /** + * Create a new conversation + * @param name the name of the new conversation + * @param applicationType the application that creates this conversation + * @param additionalInfos the additional information associated with this conversation + * @param listener listener to wait for this op to finish, gets unique id of new conversation + */ + public void createConversation( + String name, + String applicationType, + Map additionalInfos, + ActionListener listener + ) { + conversationMetaIndex.createConversation(name, applicationType, additionalInfos, listener); } /** diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java index dd50f8fffe..0f2dd2b5ce 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java @@ -18,14 +18,15 @@ package org.opensearch.ml.memory.action.conversation; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD; import java.io.IOException; import java.util.Map; +import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; -import org.opensearch.OpenSearchParseException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; @@ -107,11 +108,28 @@ public void testNamedRestRequest_WithAppType() throws IOException { } public void testRestRequest_NullName() throws IOException { - exceptionRule.expect(OpenSearchParseException.class); - exceptionRule.expectMessage("Can't get text on a VALUE_NULL"); RestRequest req = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withContent(new BytesArray("{\"name\":null}"), MediaTypeRegistry.JSON) .build(); - CreateConversationRequest.fromRestRequest(req); + CreateConversationRequest request = CreateConversationRequest.fromRestRequest(req); + Assert.assertNull(request.getName()); + } + + public void testRestRequest_WithAdditionalInfo() throws IOException { + String name = "test-name"; + Map additionalInfo = Map.of("key1", "value1", "key2", 123); + RestRequest req = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withContent( + new BytesArray( + gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, name, META_ADDITIONAL_INFO_FIELD, additionalInfo)) + ), + MediaTypeRegistry.JSON + ) + .build(); + CreateConversationRequest request = CreateConversationRequest.fromRestRequest(req); + assert (request.getName().equals(name)); + Assert.assertNull(request.getApplicationType()); + Assert.assertEquals("value1", request.getAdditionalInfos().get("key1")); + Assert.assertEquals(123, request.getAdditionalInfos().get("key2")); } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java index c6df207d16..1946bd3d8f 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java @@ -111,10 +111,10 @@ public void setup() throws IOException { public void testCreateConversation() { log.info("testing create conversation transport"); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(3); listener.onResponse("testID"); return null; - }).when(cmHandler).createConversation(any(), any(), any()); + }).when(cmHandler).createConversation(any(), any(), any(), any()); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class); verify(actionListener).onResponse(argCaptor.capture()); @@ -137,10 +137,10 @@ public void testCreateConversationWithNullName() { public void testCreateConversationFails_thenFail() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(3); listener.onFailure(new Exception("Testing Error")); return null; - }).when(cmHandler).createConversation(any(), any(), any()); + }).when(cmHandler).createConversation(any(), any(), any(), any()); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argCaptor.capture()); @@ -148,7 +148,7 @@ public void testCreateConversationFails_thenFail() { } public void testDoExecuteFails_thenFail() { - doThrow(new RuntimeException("Test doExecute Error")).when(cmHandler).createConversation(any(), any(), any()); + doThrow(new RuntimeException("Test doExecute Error")).when(cmHandler).createConversation(any(), any(), any(), any()); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argCaptor.capture()); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java index 1e8c1bcc58..0b39d546f8 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java @@ -67,21 +67,21 @@ public void testToXContent() throws IOException { } public void testToXContent_withAdditionalInfo() throws IOException { - Map additionalInfos = Map.of("key1", "value1"); + Map additionalInfos = Map.of("key1", "value1"); ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, additionalInfos); GetConversationResponse response = new GetConversationResponse(convo); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String result = BytesReference.bytes(builder).utf8ToString(); String expected = "{\"memory_id\":\"cid\",\"create_time\":\"" - + convo.getCreatedTime() - + "\",\"updated_time\":\"" - + convo.getUpdatedTime() - + "\",\"name\":\"name\"" - + ",\"additional_info\":{\"key1\":\"value1\"}" - + "}"; + + convo.getCreatedTime() + + "\",\"updated_time\":\"" + + convo.getUpdatedTime() + + "\",\"name\":\"name\"" + + ",\"additional_info\":{\"key1\":\"value1\"}" + + "}"; // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness LevenshteinDistance ld = new LevenshteinDistance(); - Assert.assertTrue (ld.getDistance(result, expected) > 0.95); + Assert.assertTrue(ld.getDistance(result, expected) > 0.95); } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java index 74ee62bb73..09bb03eb61 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java @@ -20,11 +20,13 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.Stack; import java.util.concurrent.CountDownLatch; import java.util.function.Consumer; +import org.junit.Assert; import org.junit.Before; import org.junit.Ignore; import org.opensearch.OpenSearchStatusException; @@ -561,6 +563,8 @@ public void testCanGetAConversationById() { assert (cid2.result().equals(get2.result().getId())); assert (get1.result().getName().equals("convo1")); assert (get2.result().getName().equals("convo2")); + Assert.assertTrue(convo2.getAdditionalInfos().isEmpty()); + Assert.assertTrue(get1.result().getAdditionalInfos().isEmpty()); cdl.countDown(); }, e -> { cdl.countDown(); @@ -634,4 +638,37 @@ public void testCanGetAConversationByIdSecurely() { } } + public void testCanCreateConversationWithAdditionalInfo() { + CountDownLatch cdl = new CountDownLatch(1); + StepListener cid1 = new StepListener<>(); + index.createConversation("hailong-convo", "app", Map.of("k", "v"), cid1); + + StepListener get1 = new StepListener<>(); + cid1.whenComplete(cid -> { index.getConversation(cid1.result(), get1); }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + get1.whenComplete(convo1 -> { + try { + Assert.assertEquals(cid1.result(), convo1.getId()); + Assert.assertEquals("hailong-convo", convo1.getName()); + Assert.assertNotNull(convo1.getAdditionalInfos()); + Assert.assertEquals("v", convo1.getAdditionalInfos().get("k")); + } finally { + cdl.countDown(); + } + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java index bc15f71f70..0fe96e90ce 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java @@ -29,6 +29,8 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -74,7 +76,7 @@ public void testCreateConversation_NoName_FutureSuccess() { assert (result.actionGet(200).equals("cid")); } - public void testCreateConversation_Named_FutureSucess() { + public void testCreateConversation_Named_FutureSuccess() { doAnswer(invocation -> { ActionListener al = invocation.getArgument(1); al.onResponse("cid"); @@ -84,6 +86,17 @@ public void testCreateConversation_Named_FutureSucess() { assert (result.actionGet(200).equals("cid")); } + public void testCreateConversation_AdditionalInfo_Success() throws Exception { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(3); + al.onResponse("cid"); + return null; + }).when(conversationMetaIndex).createConversation(anyString(), anyString(), any(), any()); + CompletableFuture future = new CompletableFuture<>(); + cmHandler.createConversation("FutureSuccess", "", Map.of(), ActionListener.wrap(future::complete, future::completeExceptionally)); + assert (future.get(200, TimeUnit.MILLISECONDS).equals("cid")); + } + public void testCreateInteraction_Future() { doAnswer(invocation -> { ActionListener al = invocation.getArgument(7);