From 392e0bf3e7a7368c521cd12b8198dc96a31959be Mon Sep 17 00:00:00 2001 From: ili16 Date: Tue, 27 Feb 2024 10:11:49 +0000 Subject: [PATCH] change retrieval function to respond with ID in order to allow voting --- main.go | 4 +- weaviate/weaviate.go | 5 ++ weaviate/weaviate_retrieval.go | 112 +++++++++++++++++++++++++++++++-- 3 files changed, 114 insertions(+), 7 deletions(-) diff --git a/main.go b/main.go index fd27b79..c081e37 100644 --- a/main.go +++ b/main.go @@ -55,13 +55,13 @@ func main() { log.Printf("Decoded Query: %s", decodedQuery) - response, err := weaviate.RetrieveResponse(decodedQuery) + response, err := weaviate.RetrieveRandomResponse(decodedQuery) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - c.Data(http.StatusOK, "text/plain; charset=utf-8", []byte(response)) + c.JSON(http.StatusOK, response) }) router.POST("/generate", func(c *gin.Context) { diff --git a/weaviate/weaviate.go b/weaviate/weaviate.go index 3cac16f..ba7a19e 100644 --- a/weaviate/weaviate.go +++ b/weaviate/weaviate.go @@ -91,6 +91,11 @@ func InitSchema() error { DataType: []string{"int"}, Description: "The relative rank for this response against other ones regarding the same code", Name: "rank", + ModuleConfig: map[string]interface{}{ + "text2vec-transformers": map[string]interface{}{ + "skip": true, + }, + }, }, }, } diff --git a/weaviate/weaviate_retrieval.go b/weaviate/weaviate_retrieval.go index b0b2e88..64693ef 100644 --- a/weaviate/weaviate_retrieval.go +++ b/weaviate/weaviate_retrieval.go @@ -19,6 +19,11 @@ type PromptProperties struct { Rank int `json:"rank"` } +type ResponseData struct { + ID string `json:"id"` + Response string `json:"response"` +} + func RetrieveProperties(id string) (PromptProperties, error) { client, err := loadClient() @@ -115,7 +120,7 @@ func RetrievePromptCount(code string) (int, error) { return int(countFloat), nil } -func RetrieveResponse(code string) (string, error) { +func RetrieveBestResponse(code string) (string, error) { client, err := loadClient() if err != nil { @@ -152,8 +157,6 @@ func RetrieveResponse(code string) (string, error) { panic(err) } - log.Printf("result= %v\n", result) - getPrompt, ok := result.Data["Get"].(map[string]interface{}) if !ok { return "", errors.New("unexpected response format: 'Get' field not found or not a map") @@ -198,8 +201,10 @@ func RetrieveResponse(code string) (string, error) { // If there are prompts with the same highest rank, select one randomly if len(highestRankPrompts) > 0 { - rand.Seed(time.Now().UnixNano()) - randomIndex := rand.Intn(len(highestRankPrompts)) + + source := rand.NewSource(time.Now().UnixNano()) + rng := rand.New(source) + randomIndex := rng.Intn(len(highestRankPrompts)) selectedPrompt := highestRankPrompts[randomIndex] hasResponse, ok := selectedPrompt["hasResponse"].([]interface{}) @@ -234,3 +239,100 @@ func RetrieveResponse(code string) (string, error) { return "", errors.New("no prompt found") } + +func RetrieveRandomResponse(code string) (ResponseData, error) { + client, err := loadClient() + if err != nil { + return ResponseData{}, err + } + + fields := []graphql.Field{ + {Name: "hasResponse", Fields: []graphql.Field{ + {Name: "... on Response", Fields: []graphql.Field{ + {Name: "response"}, + }}, + }}, + {Name: "_additional", Fields: []graphql.Field{ + {Name: "id"}, + }}, + } + + where := filters.Where(). + WithPath([]string{"code"}). + WithOperator(filters.Like). + WithValueText(code) + + ctx := context.Background() + result, err := client.GraphQL().Get(). + WithClassName("Prompt"). + WithFields(fields...). + WithWhere(where). + Do(ctx) + if err != nil { + panic(err) + } + + getPrompt, ok := result.Data["Get"].(map[string]interface{}) + if !ok { + return ResponseData{}, errors.New("unexpected response format: 'Get' field not found or not a map") + } + + promptData, ok := getPrompt["Prompt"].([]interface{}) + if !ok || len(promptData) == 0 { + return ResponseData{}, errors.New("unexpected response format: 'Prompt' field not found or empty list") + } + + source := rand.NewSource(time.Now().UnixNano()) + rng := rand.New(source) + randomIndex := rng.Intn(len(promptData)) + selectedPrompt := promptData[randomIndex] + + selectedPromptMap, ok := selectedPrompt.(map[string]interface{}) + if !ok { + return ResponseData{}, errors.New("unexpected response format: selected prompt data is not a map") + } + + log.Printf("selectedPromptMap: %v\n", selectedPromptMap) + + hasResponse, ok := selectedPromptMap["hasResponse"].([]interface{}) + if !ok || len(hasResponse) == 0 { + return ResponseData{}, errors.New("hasResponse field not found in prompt data or empty list") + } + + firstResponseMap, ok := hasResponse[0].(map[string]interface{}) + if !ok { + return ResponseData{}, errors.New("unexpected response format: response data is not a map") + } + + response, ok := firstResponseMap["response"].(string) + if !ok { + return ResponseData{}, errors.New("response field not found in response data or not a string") + } + + hasAdditionalInterface, ok := selectedPromptMap["_additional"] + if !ok { + return ResponseData{}, errors.New("_additional field not found in prompt data") + } + + additionalMap, ok := hasAdditionalInterface.(map[string]interface{}) + if !ok { + return ResponseData{}, errors.New("_additional field is not a map in prompt data") + } + + idInterface, ok := additionalMap["id"] + if !ok { + return ResponseData{}, errors.New("id field not found in _additional data") + } + + id, ok := idInterface.(string) + if !ok { + return ResponseData{}, errors.New("id field is not a string in _additional data") + } + + responseData := ResponseData{ + ID: id, + Response: response, + } + + return responseData, nil +}