Skip to content

Commit

Permalink
create conversation support additional info
Browse files Browse the repository at this point in the history
Signed-off-by: Hailong Cui <[email protected]>
  • Loading branch information
Hailong-am committed Jul 26, 2024
1 parent 9fd2a0b commit 8c2d61e
Show file tree
Hide file tree
Showing 10 changed files with 166 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,20 @@ public interface ConversationalMemoryHandler {
*/
public void createConversation(String name, String applicationType, ActionListener<String> listener);

/**
* Create a new conversation
* @param name the name of the new conversation
* @param applicationType the application that creates this conversation
* @param additionalInfos additional information associated with this conversation
* @param listener listener to wait for this op to finish, gets unique id of new conversation
*/
public void createConversation(
String name,
String applicationType,
Map<String, String> additionalInfos,
ActionListener<String> listener
);

/**
* Adds an interaction to the conversation indicated, updating the conversational metadata
* @param conversationId the conversation to add the interaction to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.opensearch.ml.memory.action.conversation;

import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD;

import java.io.IOException;
import java.util.Map;
Expand All @@ -27,6 +28,7 @@
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.rest.RestRequest;

Expand All @@ -40,6 +42,8 @@ public class CreateConversationRequest extends ActionRequest {
private String name = null;
@Getter
private String applicationType = null;
@Getter
private Map<String, String> additionalInfos = null;

/**
* Constructor
Expand All @@ -50,6 +54,9 @@ public CreateConversationRequest(StreamInput in) throws IOException {
super(in);
this.name = in.readOptionalString();
this.applicationType = in.readOptionalString();
if (in.readBoolean()) {
this.additionalInfos = in.readMap(StreamInput::readString, StreamInput::readString);
}
}

/**
Expand All @@ -71,6 +78,19 @@ public CreateConversationRequest(String name, String applicationType) {
this.applicationType = applicationType;
}

/**
* Constructor
* @param name name of the conversation
* @param applicationType of the conversation
* @param additionalInfos information of the conversation
*/
public CreateConversationRequest(String name, String applicationType, Map<String, String> additionalInfos) {
super();
this.name = name;
this.applicationType = applicationType;
this.additionalInfos = additionalInfos;
}

/**
* Constructor
* name will be null
Expand All @@ -82,6 +102,12 @@ public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeOptionalString(name);
out.writeOptionalString(applicationType);
if (additionalInfos == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
out.writeMap(additionalInfos, StreamOutput::writeString, StreamOutput::writeString);
}
}

@Override
Expand All @@ -101,12 +127,13 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest)
if (!restRequest.hasContent()) {
return new CreateConversationRequest();
}
try {
Map<String, String> body = restRequest.contentParser().mapStrings();
if (body.containsKey(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD)) {
try (XContentParser parser = restRequest.contentParser()) {
Map<String, Object> body = parser.map();
if (body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) != null) {
return new CreateConversationRequest(
body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD),
body.get(APPLICATION_TYPE_FIELD)
(String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD),
body.get(APPLICATION_TYPE_FIELD) == null ? null : (String) body.get(APPLICATION_TYPE_FIELD),
body.get(META_ADDITIONAL_INFO_FIELD) == null ? null : (Map<String, String>) body.get(META_ADDITIONAL_INFO_FIELD)
);
} else {
return new CreateConversationRequest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE;

import java.util.Map;

import org.opensearch.OpenSearchException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
Expand Down Expand Up @@ -79,6 +81,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis
}
String name = request.getName();
String applicationType = request.getApplicationType();
Map<String, String> additionalInfos = request.getAdditionalInfos();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<CreateConversationResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<String> al = ActionListener.wrap(r -> { internalListener.onResponse(new CreateConversationResponse(r)); }, e -> {
Expand All @@ -89,7 +92,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis
if (name == null) {
cmHandler.createConversation(al);
} else {
cmHandler.createConversation(name, applicationType, al);
cmHandler.createConversation(name, applicationType, additionalInfos, al);
}
} catch (Exception e) {
log.error("Failed to create new memory with name " + request.getName(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.time.Instant;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.OpenSearchWrapperException;
Expand Down Expand Up @@ -126,9 +127,15 @@ public void initConversationMetaIndexIfAbsent(ActionListener<Boolean> listener)
* Adds a new conversation with the specified name to the index
* @param name user-specified name of the conversation to be added
* @param applicationType the application type that creates this conversation
* @param additionalInfos the additional info that creates this conversation
* @param listener listener to wait for this to finish
*/
public void createConversation(String name, String applicationType, ActionListener<String> listener) {
public void createConversation(
String name,
String applicationType,
Map<String, String> additionalInfos,
ActionListener<String> listener
) {
initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> {
if (indexExists) {
String userstr = getUserStrFromThreadContext();
Expand All @@ -145,7 +152,9 @@ public void createConversation(String name, String applicationType, ActionListen
ConversationalIndexConstants.USER_FIELD,
userstr == null ? null : User.parse(userstr).getName(),
ConversationalIndexConstants.APPLICATION_TYPE_FIELD,
applicationType
applicationType,
ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD,
additionalInfos == null ? Map.of() : additionalInfos
);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<String> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
Expand Down Expand Up @@ -176,7 +185,7 @@ public void createConversation(String name, String applicationType, ActionListen
* @param listener listener to wait for this to finish
*/
public void createConversation(ActionListener<String> listener) {
createConversation("", "", listener);
createConversation("", "", null, listener);
}

/**
Expand All @@ -185,7 +194,7 @@ public void createConversation(ActionListener<String> listener) {
* @param listener listener to wait for this to finish
*/
public void createConversation(String name, ActionListener<String> listener) {
createConversation(name, "", listener);
createConversation(name, "", null, listener);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,23 @@ public void createConversation(String name, ActionListener<String> listener) {
* @param listener listener to wait for this op to finish, gets unique id of new conversation
*/
public void createConversation(String name, String applicationType, ActionListener<String> listener) {
conversationMetaIndex.createConversation(name, applicationType, listener);
conversationMetaIndex.createConversation(name, applicationType, null, listener);
}

/**
* Create a new conversation
* @param name the name of the new conversation
* @param applicationType the application that creates this conversation
* @param additionalInfos the additional information associated with this conversation
* @param listener listener to wait for this op to finish, gets unique id of new conversation
*/
public void createConversation(
String name,
String applicationType,
Map<String, String> additionalInfos,
ActionListener<String> listener
) {
conversationMetaIndex.createConversation(name, applicationType, additionalInfos, listener);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
package org.opensearch.ml.memory.action.conversation;

import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD;

import java.io.IOException;
import java.util.Map;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.opensearch.OpenSearchParseException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
Expand Down Expand Up @@ -107,11 +108,28 @@ public void testNamedRestRequest_WithAppType() throws IOException {
}

public void testRestRequest_NullName() throws IOException {
exceptionRule.expect(OpenSearchParseException.class);
exceptionRule.expectMessage("Can't get text on a VALUE_NULL");
RestRequest req = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
.withContent(new BytesArray("{\"name\":null}"), MediaTypeRegistry.JSON)
.build();
CreateConversationRequest.fromRestRequest(req);
CreateConversationRequest request = CreateConversationRequest.fromRestRequest(req);
Assert.assertNull(request.getName());
}

public void testRestRequest_WithAdditionalInfo() throws IOException {
String name = "test-name";
Map<String, Object> additionalInfo = Map.of("key1", "value1", "key2", 123);
RestRequest req = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
.withContent(
new BytesArray(
gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, name, META_ADDITIONAL_INFO_FIELD, additionalInfo))
),
MediaTypeRegistry.JSON
)
.build();
CreateConversationRequest request = CreateConversationRequest.fromRestRequest(req);
assert (request.getName().equals(name));
Assert.assertNull(request.getApplicationType());
Assert.assertEquals("value1", request.getAdditionalInfos().get("key1"));
Assert.assertEquals(123, request.getAdditionalInfos().get("key2"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ public void setup() throws IOException {
public void testCreateConversation() {
log.info("testing create conversation transport");
doAnswer(invocation -> {
ActionListener<String> listener = invocation.getArgument(2);
ActionListener<String> listener = invocation.getArgument(3);
listener.onResponse("testID");
return null;
}).when(cmHandler).createConversation(any(), any(), any());
}).when(cmHandler).createConversation(any(), any(), any(), any());
action.doExecute(null, request, actionListener);
ArgumentCaptor<CreateConversationResponse> argCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class);
verify(actionListener).onResponse(argCaptor.capture());
Expand All @@ -137,18 +137,18 @@ public void testCreateConversationWithNullName() {

public void testCreateConversationFails_thenFail() {
doAnswer(invocation -> {
ActionListener<String> listener = invocation.getArgument(2);
ActionListener<String> listener = invocation.getArgument(3);
listener.onFailure(new Exception("Testing Error"));
return null;
}).when(cmHandler).createConversation(any(), any(), any());
}).when(cmHandler).createConversation(any(), any(), any(), any());
action.doExecute(null, request, actionListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argCaptor.capture());
assert (argCaptor.getValue().getMessage().equals("Testing Error"));
}

public void testDoExecuteFails_thenFail() {
doThrow(new RuntimeException("Test doExecute Error")).when(cmHandler).createConversation(any(), any(), any());
doThrow(new RuntimeException("Test doExecute Error")).when(cmHandler).createConversation(any(), any(), any(), any());
action.doExecute(null, request, actionListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argCaptor.capture());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,21 +67,21 @@ public void testToXContent() throws IOException {
}

public void testToXContent_withAdditionalInfo() throws IOException {
Map<String,String> additionalInfos = Map.of("key1", "value1");
Map<String, String> additionalInfos = Map.of("key1", "value1");
ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, additionalInfos);
GetConversationResponse response = new GetConversationResponse(convo);
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
String result = BytesReference.bytes(builder).utf8ToString();
String expected = "{\"memory_id\":\"cid\",\"create_time\":\""
+ convo.getCreatedTime()
+ "\",\"updated_time\":\""
+ convo.getUpdatedTime()
+ "\",\"name\":\"name\""
+ ",\"additional_info\":{\"key1\":\"value1\"}"
+ "}";
+ convo.getCreatedTime()
+ "\",\"updated_time\":\""
+ convo.getUpdatedTime()
+ "\",\"name\":\"name\""
+ ",\"additional_info\":{\"key1\":\"value1\"}"
+ "}";
// Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness
LevenshteinDistance ld = new LevenshteinDistance();
Assert.assertTrue (ld.getDistance(result, expected) > 0.95);
Assert.assertTrue(ld.getDistance(result, expected) > 0.95);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import java.util.concurrent.CountDownLatch;
import java.util.function.Consumer;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
import org.opensearch.OpenSearchStatusException;
Expand Down Expand Up @@ -561,6 +563,8 @@ public void testCanGetAConversationById() {
assert (cid2.result().equals(get2.result().getId()));
assert (get1.result().getName().equals("convo1"));
assert (get2.result().getName().equals("convo2"));
Assert.assertTrue(convo2.getAdditionalInfos().isEmpty());
Assert.assertTrue(get1.result().getAdditionalInfos().isEmpty());
cdl.countDown();
}, e -> {
cdl.countDown();
Expand Down Expand Up @@ -634,4 +638,37 @@ public void testCanGetAConversationByIdSecurely() {
}
}

public void testCanCreateConversationWithAdditionalInfo() {
CountDownLatch cdl = new CountDownLatch(1);
StepListener<String> cid1 = new StepListener<>();
index.createConversation("hailong-convo", "app", Map.of("k", "v"), cid1);

StepListener<ConversationMeta> get1 = new StepListener<>();
cid1.whenComplete(cid -> { index.getConversation(cid1.result(), get1); }, e -> {
cdl.countDown();
log.error(e);
assert (false);
});

get1.whenComplete(convo1 -> {
try {
Assert.assertEquals(cid1.result(), convo1.getId());
Assert.assertEquals("hailong-convo", convo1.getName());
Assert.assertNotNull(convo1.getAdditionalInfos());
Assert.assertEquals("v", convo1.getAdditionalInfos().get("k"));
} finally {
cdl.countDown();
}
}, e -> {
cdl.countDown();
log.error(e);
assert (false);
});

try {
cdl.await();
} catch (InterruptedException e) {
log.error(e);
}
}
}
Loading

0 comments on commit 8c2d61e

Please sign in to comment.