From 8cc8a6d20596dea1262cdb2ac80f970538c00313 Mon Sep 17 00:00:00 2001 From: ili16 Date: Thu, 29 Feb 2024 10:44:15 +0000 Subject: [PATCH] add query parameter to retrieve best* or random prompt --- main.go | 20 +++- weaviate/weaviate_retrieval.go | 197 +++++++++++++++------------------ 2 files changed, 105 insertions(+), 112 deletions(-) diff --git a/main.go b/main.go index c081e37..87145b4 100644 --- a/main.go +++ b/main.go @@ -45,6 +45,8 @@ func main() { router.GET("/weaviate/retrieveresponse", func(c *gin.Context) { searchQuery := c.Query("query") + upvoteStr := c.Query("best") + best := upvoteStr == "true" // Decode the search query decodedQuery, err := url.QueryUnescape(searchQuery) @@ -55,10 +57,20 @@ func main() { log.Printf("Decoded Query: %s", decodedQuery) - response, err := weaviate.RetrieveRandomResponse(decodedQuery) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + var response interface{} + + if best { + response, err = weaviate.RetrieveBestResponse(decodedQuery) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } else { + response, err = weaviate.RetrieveRandomResponse(decodedQuery) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } } c.JSON(http.StatusOK, response) diff --git a/weaviate/weaviate_retrieval.go b/weaviate/weaviate_retrieval.go index 64693ef..a50f713 100644 --- a/weaviate/weaviate_retrieval.go +++ b/weaviate/weaviate_retrieval.go @@ -4,10 +4,9 @@ import ( "context" "encoding/json" "errors" - "fmt" "github.com/weaviate/weaviate-go-client/v4/weaviate/filters" "github.com/weaviate/weaviate-go-client/v4/weaviate/graphql" - "log" + "github.com/weaviate/weaviate/entities/models" "math/rand" "time" ) @@ -120,75 +119,42 @@ func RetrievePromptCount(code string) (int, error) { return int(countFloat), nil } -func RetrieveBestResponse(code string) (string, error) { +func RetrieveBestResponse(code string) (ResponseData, error) { - client, err := loadClient() + responses, err := RetrieveResponsesRankDesc(code) if err != nil { - return "", err - } - - fields := []graphql.Field{ - {Name: "instruct"}, - {Name: "rank"}, - {Name: "hasResponse", Fields: []graphql.Field{ - {Name: "... on Response", Fields: []graphql.Field{ - {Name: "response"}, - }}, - }}, - } - - where := filters.Where(). - WithPath([]string{"code"}). - WithOperator(filters.Like). - WithValueText(code) - - byRankDesc := graphql.Sort{ - Path: []string{"rank"}, Order: graphql.Desc, - } - - ctx := context.Background() - result, err := client.GraphQL().Get(). - WithClassName("Prompt"). - WithSort(byRankDesc). - WithFields(fields...). - WithWhere(where). - Do(ctx) - if err != nil { - panic(err) + return ResponseData{}, err } - getPrompt, ok := result.Data["Get"].(map[string]interface{}) + getPrompt, ok := responses.Data["Get"].(map[string]interface{}) if !ok { - return "", errors.New("unexpected response format: 'Get' field not found or not a map") + return ResponseData{}, errors.New("unexpected response format: 'Get' field not found or not a map") } promptData, ok := getPrompt["Prompt"].([]interface{}) if !ok || len(promptData) == 0 { - return "", errors.New("unexpected response format: 'Prompt' field not found or empty list") + return ResponseData{}, errors.New("unexpected response format: 'Prompt' field not found or empty list") } - // Initialize variables to track the prompt with the highest rank var highestRank int var highestRankPrompts []map[string]interface{} - // Iterate through each prompt to find the one with the highest rank for _, prompt := range promptData { promptMap, ok := prompt.(map[string]interface{}) if !ok { - return "", errors.New("unexpected response format: prompt data is not a map") + return ResponseData{}, errors.New("unexpected response format: prompt data is not a map") } rankInterface, ok := promptMap["rank"] if !ok { - return "", errors.New("rank field not found in prompt data") + return ResponseData{}, errors.New("rank field not found in prompt data") } rank, ok := rankInterface.(float64) if !ok { - return "", errors.New("rank field is not a number") + return ResponseData{}, errors.New("rank field is not a number") } - // Convert float64 to int rankInt := int(rank) if rankInt > highestRank { @@ -199,7 +165,6 @@ func RetrieveBestResponse(code string) (string, error) { } } - // If there are prompts with the same highest rank, select one randomly if len(highestRankPrompts) > 0 { source := rand.NewSource(time.Now().UnixNano()) @@ -207,43 +172,78 @@ func RetrieveBestResponse(code string) (string, error) { randomIndex := rng.Intn(len(highestRankPrompts)) selectedPrompt := highestRankPrompts[randomIndex] - hasResponse, ok := selectedPrompt["hasResponse"].([]interface{}) - if !ok || len(hasResponse) == 0 { - return "", errors.New("hasResponse field not found in prompt data or empty list") + response, err := ExtractResponse(selectedPrompt) + if err != nil { + return ResponseData{}, err } - firstResponseMap, ok := hasResponse[0].(map[string]interface{}) - if !ok { - return "", errors.New("unexpected response format: response data is not a map") + id, err := ExtractID(selectedPrompt) + if err != nil { + return ResponseData{}, err } - response, ok := firstResponseMap["response"].(string) - if !ok { - return "", errors.New("response field not found in response data or not a string") + responseData := ResponseData{ + ID: id, + Response: response, } - jsonData, err := json.Marshal(response) - if err != nil { - fmt.Println("Error:", err) - return "", err - } + return responseData, nil + } - log.Printf("Selected Response: %v\n", response) + return ResponseData{}, errors.New("no prompt found") - // Add a newline character to the end of the string - jsonDataWithNewline := append(jsonData, '\n') +} - return string(jsonDataWithNewline), nil +func RetrieveRandomResponse(code string) (ResponseData, error) { + + responses, err := RetrieveResponsesRankDesc(code) + if err != nil { + return ResponseData{}, err + } + + getPrompt, ok := responses.Data["Get"].(map[string]interface{}) + if !ok { + return ResponseData{}, errors.New("unexpected response format: 'Get' field not found or not a map") } - return "", errors.New("no prompt found") + promptData, ok := getPrompt["Prompt"].([]interface{}) + if !ok || len(promptData) == 0 { + return ResponseData{}, errors.New("unexpected response format: 'Prompt' field not found or empty list") + } + + source := rand.NewSource(time.Now().UnixNano()) + rng := rand.New(source) + randomIndex := rng.Intn(len(promptData)) + selectedPrompt := promptData[randomIndex] + + selectedPromptMap, ok := selectedPrompt.(map[string]interface{}) + if !ok { + return ResponseData{}, errors.New("unexpected response format: selected prompt data is not a map") + } + response, err := ExtractResponse(selectedPromptMap) + if err != nil { + return ResponseData{}, err + } + + id, err := ExtractID(selectedPromptMap) + if err != nil { + return ResponseData{}, err + } + + responseData := ResponseData{ + ID: id, + Response: response, + } + + return responseData, nil } -func RetrieveRandomResponse(code string) (ResponseData, error) { +func RetrieveResponsesRankDesc(code string) (*models.GraphQLResponse, error) { + client, err := loadClient() if err != nil { - return ResponseData{}, err + return nil, err } fields := []graphql.Field{ @@ -269,70 +269,51 @@ func RetrieveRandomResponse(code string) (ResponseData, error) { WithWhere(where). Do(ctx) if err != nil { - panic(err) + return nil, err } - getPrompt, ok := result.Data["Get"].(map[string]interface{}) + return result, nil +} + +func ExtractID(selectedPrompt map[string]interface{}) (string, error) { + hasAdditionalInterface, ok := selectedPrompt["_additional"] if !ok { - return ResponseData{}, errors.New("unexpected response format: 'Get' field not found or not a map") + return "", errors.New("_additional field not found in prompt data") } - promptData, ok := getPrompt["Prompt"].([]interface{}) - if !ok || len(promptData) == 0 { - return ResponseData{}, errors.New("unexpected response format: 'Prompt' field not found or empty list") + additionalMap, ok := hasAdditionalInterface.(map[string]interface{}) + if !ok { + return "", errors.New("_additional field is not a map in prompt data") } - source := rand.NewSource(time.Now().UnixNano()) - rng := rand.New(source) - randomIndex := rng.Intn(len(promptData)) - selectedPrompt := promptData[randomIndex] + idInterface, ok := additionalMap["id"] + if !ok { + return "", errors.New("id field not found in _additional data") + } - selectedPromptMap, ok := selectedPrompt.(map[string]interface{}) + id, ok := idInterface.(string) if !ok { - return ResponseData{}, errors.New("unexpected response format: selected prompt data is not a map") + return "", errors.New("id field is not a string in _additional data") } - log.Printf("selectedPromptMap: %v\n", selectedPromptMap) + return id, nil +} +func ExtractResponse(selectedPromptMap map[string]interface{}) (string, error) { hasResponse, ok := selectedPromptMap["hasResponse"].([]interface{}) if !ok || len(hasResponse) == 0 { - return ResponseData{}, errors.New("hasResponse field not found in prompt data or empty list") + return "", errors.New("hasResponse field not found in prompt data or empty list") } firstResponseMap, ok := hasResponse[0].(map[string]interface{}) if !ok { - return ResponseData{}, errors.New("unexpected response format: response data is not a map") + return "", errors.New("unexpected response format: response data is not a map") } response, ok := firstResponseMap["response"].(string) if !ok { - return ResponseData{}, errors.New("response field not found in response data or not a string") - } - - hasAdditionalInterface, ok := selectedPromptMap["_additional"] - if !ok { - return ResponseData{}, errors.New("_additional field not found in prompt data") - } - - additionalMap, ok := hasAdditionalInterface.(map[string]interface{}) - if !ok { - return ResponseData{}, errors.New("_additional field is not a map in prompt data") + return "", errors.New("response field not found in response data or not a string") } - idInterface, ok := additionalMap["id"] - if !ok { - return ResponseData{}, errors.New("id field not found in _additional data") - } - - id, ok := idInterface.(string) - if !ok { - return ResponseData{}, errors.New("id field is not a string in _additional data") - } - - responseData := ResponseData{ - ID: id, - Response: response, - } - - return responseData, nil + return response, nil }