Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement]Adding additional info for memory metadata #2750

Merged
merged 2 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -541,4 +541,5 @@ public class CommonValue {
public static final Version VERSION_2_13_0 = Version.fromString("2.13.0");
public static final Version VERSION_2_14_0 = Version.fromString("2.14.0");
public static final Version VERSION_2_16_0 = Version.fromString("2.16.0");
public static final Version VERSION_2_17_0 = Version.fromString("2.17.0");
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,26 @@
import java.util.Map;
import java.util.Objects;

import org.opensearch.action.index.IndexRequest;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.search.SearchHit;

import lombok.AllArgsConstructor;
import lombok.Getter;

import static org.opensearch.ml.common.CommonValue.VERSION_2_17_0;

/**
* Class for holding conversational metadata
*/
@AllArgsConstructor
public class ConversationMeta implements Writeable, ToXContentObject {

public static final Version MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO = CommonValue.VERSION_2_17_0;
@Getter
private String id;
@Getter
Expand All @@ -49,6 +52,8 @@ public class ConversationMeta implements Writeable, ToXContentObject {
private String name;
@Getter
private String user;
@Getter
private Map<String, String> additionalInfos;

/**
* Creates a conversationMeta object from a SearchHit object
Expand All @@ -71,7 +76,8 @@ public static ConversationMeta fromMap(String id, Map<String, Object> docFields)
Instant updated = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_UPDATED_TIME_FIELD));
String name = (String) docFields.get(ConversationalIndexConstants.META_NAME_FIELD);
String user = (String) docFields.get(ConversationalIndexConstants.USER_FIELD);
return new ConversationMeta(id, created, updated, name, user);
Map<String, String> additionalInfos = (Map<String, String>)docFields.get(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD);
return new ConversationMeta(id, created, updated, name, user, additionalInfos);
}

/**
Expand All @@ -87,7 +93,13 @@ public static ConversationMeta fromStream(StreamInput in) throws IOException {
Instant updated = in.readInstant();
String name = in.readString();
String user = in.readOptionalString();
return new ConversationMeta(id, created, updated, name, user);
Map<String, String> additionalInfos = null;
if (in.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) {
if (in.readBoolean()) {
additionalInfos = in.readMap(StreamInput::readString, StreamInput::readString);
}
}
return new ConversationMeta(id, created, updated, name, user, additionalInfos);
}

@Override
Expand All @@ -97,6 +109,14 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeInstant(updatedTime);
out.writeString(name);
out.writeOptionalString(user);
if(out.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) {
if (additionalInfos == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
out.writeMap(additionalInfos, StreamOutput::writeString, StreamOutput::writeString);
}
}
}

@Override
Expand All @@ -119,6 +139,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para
if(this.user != null) {
builder.field(ConversationalIndexConstants.USER_FIELD, this.user);
}
if (this.additionalInfos != null) {
builder.field(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, this.additionalInfos);
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
*/
public class ConversationalIndexConstants {
/** Version of the meta index schema */
public final static Integer META_INDEX_SCHEMA_VERSION = 1;
public final static Integer META_INDEX_SCHEMA_VERSION = 2;
/** Name of the conversational metadata index */
public final static String META_INDEX_NAME = ".plugins-ml-memory-meta";
/** Name of the metadata field for initial timestamp */
Expand All @@ -37,6 +37,9 @@ public class ConversationalIndexConstants {
public final static String USER_FIELD = "user";
/** Name of the application that created this conversation */
public final static String APPLICATION_TYPE_FIELD = "application_type";
/** Name of the additional information for this memory */
public final static String META_ADDITIONAL_INFO_FIELD = "additional_info";

/** Mappings for the conversational metadata index */
public final static String META_MAPPING = "{\n"
+ " \"_meta\": {\n"
Expand All @@ -57,7 +60,10 @@ public class ConversationalIndexConstants {
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ APPLICATION_TYPE_FIELD
+ "\": {\"type\": \"keyword\"}\n"
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ META_ADDITIONAL_INFO_FIELD
+ "\": {\"type\": \"flat_object\"}\n"
+ " }\n"
+ "}";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;

public class ConversationMetaTests {
Expand All @@ -30,7 +31,7 @@ public class ConversationMetaTests {
@Before
public void setUp() {
time = Instant.now();
conversationMeta = new ConversationMeta("test_id", time, time, "test_name", "admin");
conversationMeta = new ConversationMeta("test_id", time, time, "test_name", "admin", null);
}

@Test
Expand All @@ -41,6 +42,7 @@ public void test_fromSearchHit() throws IOException {
content.field(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, time);
content.field(ConversationalIndexConstants.META_NAME_FIELD, "meta name");
content.field(ConversationalIndexConstants.USER_FIELD, "admin");
content.field(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, Map.of("test_key", "test_value"));
content.endObject();

SearchHit[] hits = new SearchHit[1];
Expand All @@ -50,6 +52,7 @@ public void test_fromSearchHit() throws IOException {
assertEquals(conversationMeta.getId(), "cId");
assertEquals(conversationMeta.getName(), "meta name");
assertEquals(conversationMeta.getUser(), "admin");
assertEquals(conversationMeta.getAdditionalInfos().get("test_key"), "test_value");
}

@Test
Expand Down Expand Up @@ -85,7 +88,7 @@ public void test_fromStream() throws IOException {

@Test
public void test_ToXContent() throws IOException {
ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin");
ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin", null);
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
conversationMeta.toXContent(builder, EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Expand All @@ -94,13 +97,13 @@ public void test_ToXContent() throws IOException {

@Test
public void test_toString() {
ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin");
ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin", null);
assertEquals("{id=test_id, name=test meta, created=1970-01-01T00:00:00.123Z, updated=1970-01-01T00:00:00.123Z, user=admin}", conversationMeta.toString());
}

@Test
public void test_equal() {
ConversationMeta meta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin");
ConversationMeta meta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin", null);
assertEquals(meta.equals(conversationMeta), false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,20 @@ public interface ConversationalMemoryHandler {
*/
public void createConversation(String name, String applicationType, ActionListener<String> listener);

/**
* Create a new conversation
* @param name the name of the new conversation
* @param applicationType the application that creates this conversation
* @param additionalInfos additional information associated with this conversation
* @param listener listener to wait for this op to finish, gets unique id of new conversation
*/
public void createConversation(
String name,
String applicationType,
Map<String, String> additionalInfos,
ActionListener<String> listener
);

/**
* Adds an interaction to the conversation indicated, updating the conversational metadata
* @param conversationId the conversation to add the interaction to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@
package org.opensearch.ml.memory.action.conversation;

import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD;

import java.io.IOException;
import java.util.Map;

import org.opensearch.OpenSearchParseException;
import org.opensearch.Version;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.rest.RestRequest;

Expand All @@ -36,10 +40,14 @@
* Action Request for creating a conversation
*/
public class CreateConversationRequest extends ActionRequest {
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO = CommonValue.VERSION_2_17_0;

@Getter
private String name = null;
@Getter
private String applicationType = null;
@Getter
private Map<String, String> additionalInfos = null;

/**
* Constructor
Expand All @@ -50,6 +58,11 @@ public CreateConversationRequest(StreamInput in) throws IOException {
super(in);
this.name = in.readOptionalString();
this.applicationType = in.readOptionalString();
if (in.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) {
if (in.readBoolean()) {
this.additionalInfos = in.readMap(StreamInput::readString, StreamInput::readString);
}
}
}

/**
Expand All @@ -71,6 +84,19 @@ public CreateConversationRequest(String name, String applicationType) {
this.applicationType = applicationType;
}

/**
* Constructor
* @param name name of the conversation
* @param applicationType of the conversation
* @param additionalInfos information of the conversation
*/
public CreateConversationRequest(String name, String applicationType, Map<String, String> additionalInfos) {
super();
this.name = name;
this.applicationType = applicationType;
this.additionalInfos = additionalInfos;
}

/**
* Constructor
* name will be null
Expand All @@ -82,6 +108,14 @@ public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeOptionalString(name);
out.writeOptionalString(applicationType);
if (out.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) {
if (additionalInfos == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
out.writeMap(additionalInfos, StreamOutput::writeString, StreamOutput::writeString);
}
}
}

@Override
Expand All @@ -101,12 +135,13 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest)
if (!restRequest.hasContent()) {
return new CreateConversationRequest();
}
try {
Map<String, String> body = restRequest.contentParser().mapStrings();
if (body.containsKey(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD)) {
try (XContentParser parser = restRequest.contentParser()) {
Map<String, Object> body = parser.map();
if (body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) != null) {
return new CreateConversationRequest(
body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD),
body.get(APPLICATION_TYPE_FIELD)
(String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD),
body.get(APPLICATION_TYPE_FIELD) == null ? null : (String) body.get(APPLICATION_TYPE_FIELD),
body.get(META_ADDITIONAL_INFO_FIELD) == null ? null : (Map<String, String>) body.get(META_ADDITIONAL_INFO_FIELD)
);
} else {
return new CreateConversationRequest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE;

import java.util.Map;

import org.opensearch.OpenSearchException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
Expand Down Expand Up @@ -79,6 +81,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis
}
String name = request.getName();
String applicationType = request.getApplicationType();
Map<String, String> additionalInfos = request.getAdditionalInfos();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<CreateConversationResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<String> al = ActionListener.wrap(r -> { internalListener.onResponse(new CreateConversationResponse(r)); }, e -> {
Expand All @@ -89,7 +92,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis
if (name == null) {
cmHandler.createConversation(al);
} else {
cmHandler.createConversation(name, applicationType, al);
cmHandler.createConversation(name, applicationType, additionalInfos, al);
}
} catch (Exception e) {
log.error("Failed to create new memory with name " + request.getName(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.memory.action.conversation;

import static org.opensearch.action.ValidateActions.addValidationError;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD;

import java.io.ByteArrayInputStream;
Expand Down Expand Up @@ -35,7 +36,7 @@ public class UpdateConversationRequest extends ActionRequest {
private String conversationId;
private Map<String, Object> updateContent;

private static final Set<String> allowedList = new HashSet<>(Arrays.asList(META_NAME_FIELD));
private static final Set<String> allowedList = new HashSet<>(Arrays.asList(META_NAME_FIELD, META_ADDITIONAL_INFO_FIELD));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would allow update/delete anything inside META_ADDITIONAL_INFO_FIELD. I think you need to put some restrictions to not allow updating certain fields inside it like the application/memory type, etc?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that makes sense. We can initially make it open and then add restrictions when the need arises.


@Builder
public UpdateConversationRequest(String conversationId, Map<String, Object> updateContent) {
Expand Down
Loading
Loading