Skip to content

Commit

Permalink
Allow llmQuestion to be optional when llmMessages is used. (Issue #3… (
Browse files Browse the repository at this point in the history
…#3072)

* Allow llmQuestion to be optional when llmMessages is used.  (Issue #3067)

Signed-off-by: Austin Lee <[email protected]>

* Remove unused lines.

Signed-off-by: Austin Lee <[email protected]>

---------

Signed-off-by: Austin Lee <[email protected]>
  • Loading branch information
austintlee authored Oct 9, 2024
1 parent 74c211e commit 48d275d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
+ " \"ext\": {\n"
+ " \"generative_qa_parameters\": {\n"
+ " \"llm_model\": \"%s\",\n"
+ " \"llm_question\": \"%s\",\n"
+ " \"system_prompt\": \"%s\",\n"
+ " \"user_instructions\": \"%s\",\n"
+ " \"context_size\": %d,\n"
Expand All @@ -378,8 +377,6 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
+ " \"ext\": {\n"
+ " \"generative_qa_parameters\": {\n"
+ " \"llm_model\": \"%s\",\n"
+ " \"llm_question\": \"%s\",\n"
// + " \"system_prompt\": \"%s\",\n"
+ " \"user_instructions\": \"%s\",\n"
+ " \"context_size\": %d,\n"
+ " \"message_size\": %d,\n"
Expand Down Expand Up @@ -723,8 +720,12 @@ public void testBM25WithBedrock() throws Exception {
public void testBM25WithBedrockConverse() throws Exception {
// Skip test if key is null
if (AWS_ACCESS_KEY_ID == null) {
System.out.println("Skipping testBM25WithBedrockConverse because AWS_ACCESS_KEY_ID is null");
return;
}

System.out.println("Running testBM25WithBedrockConverse");

Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
Expand Down Expand Up @@ -775,8 +776,11 @@ public void testBM25WithBedrockConverse() throws Exception {
public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception {
// Skip test if key is null
if (AWS_ACCESS_KEY_ID == null) {
System.out.println("Skipping testBM25WithBedrockConverseUsingLlmMessages because AWS_ACCESS_KEY_ID is null");
return;
}
System.out.println("Running testBM25WithBedrockConverseUsingLlmMessages");

Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
Expand Down Expand Up @@ -835,8 +839,11 @@ public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception {
public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws Exception {
// Skip test if key is null
if (AWS_ACCESS_KEY_ID == null) {
System.out.println("Skipping testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat because AWS_ACCESS_KEY_ID is null");
return;
}

System.out.println("Running testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat");
Response response = createConnector(BEDROCK_DOCUMENT_CONVERSE_CONNECTOR_BLUEPRINT2);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
Expand Down Expand Up @@ -894,8 +901,11 @@ public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws
public void testBM25WithOpenAIWithConversation() throws Exception {
// Skip test if key is null
if (OPENAI_KEY == null) {
System.out.println("Skipping testBM25WithOpenAIWithConversation because OPENAI_KEY is null");
return;
}
System.out.println("Running testBM25WithOpenAIWithConversation");

Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
Expand Down Expand Up @@ -951,8 +961,11 @@ public void testBM25WithOpenAIWithConversation() throws Exception {
public void testBM25WithOpenAIWithConversationAndImage() throws Exception {
// Skip test if key is null
if (OPENAI_KEY == null) {
System.out.println("Skipping testBM25WithOpenAIWithConversationAndImage because OPENAI_KEY is null");
return;
}
System.out.println("Running testBM25WithOpenAIWithConversationAndImage");

Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
Expand Down Expand Up @@ -1245,7 +1258,6 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
requestParameters.source,
requestParameters.match,
requestParameters.llmModel,
requestParameters.llmQuestion,
requestParameters.systemPrompt,
requestParameters.userInstructions,
requestParameters.contextSize,
Expand All @@ -1268,8 +1280,6 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
requestParameters.source,
requestParameters.match,
requestParameters.llmModel,
requestParameters.llmQuestion,
// requestParameters.systemPrompt,
requestParameters.userInstructions,
requestParameters.contextSize,
requestParameters.interactionSize,
Expand Down Expand Up @@ -1309,7 +1319,6 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
requestParameters.source,
requestParameters.match,
requestParameters.llmModel,
requestParameters.llmQuestion,
requestParameters.systemPrompt,
requestParameters.userInstructions,
requestParameters.contextSize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,11 @@ public GenerativeQAParameters(
this.conversationId = conversationId;
this.llmModel = llmModel;

// TODO: keep this requirement until we can extract the question from the query or from the request processor parameters
// for question rewriting.
Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION + " must be provided.");
Preconditions
.checkArgument(
!(Strings.isNullOrEmpty(llmQuestion) && (llmMessages == null || llmMessages.isEmpty())),
"At least one of " + LLM_QUESTION + " or " + LLM_MESSAGES_FIELD + " must be provided."
);
this.llmQuestion = llmQuestion;
this.systemPrompt = systemPrompt;
this.userInstructions = userInstructions;
Expand All @@ -185,7 +187,7 @@ public GenerativeQAParameters(
public GenerativeQAParameters(StreamInput input) throws IOException {
this.conversationId = input.readOptionalString();
this.llmModel = input.readOptionalString();
this.llmQuestion = input.readString();
this.llmQuestion = input.readOptionalString();
this.systemPrompt = input.readOptionalString();
this.userInstructions = input.readOptionalString();
this.contextSize = input.readInt();
Expand Down Expand Up @@ -246,9 +248,7 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(conversationId);
out.writeOptionalString(llmModel);

Preconditions.checkNotNull(llmQuestion, "llm_question must not be null.");
out.writeString(llmQuestion);
out.writeOptionalString(llmQuestion);
out.writeOptionalString(systemPrompt);
out.writeOptionalString(userInstructions);
out.writeInt(contextSize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ public void testMiscMethods() throws IOException {

StreamOutput so = mock(StreamOutput.class);
builder1.writeTo(so);
verify(so, times(5)).writeOptionalString(any());
verify(so, times(1)).writeString(any());
verify(so, times(6)).writeOptionalString(any());
}

public void testParse() throws IOException {
Expand Down

0 comments on commit 48d275d

Please sign in to comment.