Skip to content

Commit

Permalink
add query parameter to retrieve best* or random prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
ili16 committed Feb 29, 2024
1 parent f0dbb99 commit 8cc8a6d
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 112 deletions.
20 changes: 16 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ func main() {

router.GET("/weaviate/retrieveresponse", func(c *gin.Context) {
searchQuery := c.Query("query")
upvoteStr := c.Query("best")
best := upvoteStr == "true"

// Decode the search query
decodedQuery, err := url.QueryUnescape(searchQuery)
Expand All @@ -55,10 +57,20 @@ func main() {

log.Printf("Decoded Query: %s", decodedQuery)

response, err := weaviate.RetrieveRandomResponse(decodedQuery)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
var response interface{}

if best {
response, err = weaviate.RetrieveBestResponse(decodedQuery)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} else {
response, err = weaviate.RetrieveRandomResponse(decodedQuery)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}

c.JSON(http.StatusOK, response)
Expand Down
197 changes: 89 additions & 108 deletions weaviate/weaviate_retrieval.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/weaviate/weaviate-go-client/v4/weaviate/filters"
"github.com/weaviate/weaviate-go-client/v4/weaviate/graphql"
"log"
"github.com/weaviate/weaviate/entities/models"
"math/rand"
"time"
)
Expand Down Expand Up @@ -120,75 +119,42 @@ func RetrievePromptCount(code string) (int, error) {
return int(countFloat), nil
}

func RetrieveBestResponse(code string) (string, error) {
func RetrieveBestResponse(code string) (ResponseData, error) {

client, err := loadClient()
responses, err := RetrieveResponsesRankDesc(code)
if err != nil {
return "", err
}

fields := []graphql.Field{
{Name: "instruct"},
{Name: "rank"},
{Name: "hasResponse", Fields: []graphql.Field{
{Name: "... on Response", Fields: []graphql.Field{
{Name: "response"},
}},
}},
}

where := filters.Where().
WithPath([]string{"code"}).
WithOperator(filters.Like).
WithValueText(code)

byRankDesc := graphql.Sort{
Path: []string{"rank"}, Order: graphql.Desc,
}

ctx := context.Background()
result, err := client.GraphQL().Get().
WithClassName("Prompt").
WithSort(byRankDesc).
WithFields(fields...).
WithWhere(where).
Do(ctx)
if err != nil {
panic(err)
return ResponseData{}, err
}

getPrompt, ok := result.Data["Get"].(map[string]interface{})
getPrompt, ok := responses.Data["Get"].(map[string]interface{})
if !ok {
return "", errors.New("unexpected response format: 'Get' field not found or not a map")
return ResponseData{}, errors.New("unexpected response format: 'Get' field not found or not a map")
}

promptData, ok := getPrompt["Prompt"].([]interface{})
if !ok || len(promptData) == 0 {
return "", errors.New("unexpected response format: 'Prompt' field not found or empty list")
return ResponseData{}, errors.New("unexpected response format: 'Prompt' field not found or empty list")
}

// Initialize variables to track the prompt with the highest rank
var highestRank int
var highestRankPrompts []map[string]interface{}

// Iterate through each prompt to find the one with the highest rank
for _, prompt := range promptData {
promptMap, ok := prompt.(map[string]interface{})
if !ok {
return "", errors.New("unexpected response format: prompt data is not a map")
return ResponseData{}, errors.New("unexpected response format: prompt data is not a map")
}

rankInterface, ok := promptMap["rank"]
if !ok {
return "", errors.New("rank field not found in prompt data")
return ResponseData{}, errors.New("rank field not found in prompt data")
}

rank, ok := rankInterface.(float64)
if !ok {
return "", errors.New("rank field is not a number")
return ResponseData{}, errors.New("rank field is not a number")
}

// Convert float64 to int
rankInt := int(rank)

if rankInt > highestRank {
Expand All @@ -199,51 +165,85 @@ func RetrieveBestResponse(code string) (string, error) {
}
}

// If there are prompts with the same highest rank, select one randomly
if len(highestRankPrompts) > 0 {

source := rand.NewSource(time.Now().UnixNano())
rng := rand.New(source)
randomIndex := rng.Intn(len(highestRankPrompts))
selectedPrompt := highestRankPrompts[randomIndex]

hasResponse, ok := selectedPrompt["hasResponse"].([]interface{})
if !ok || len(hasResponse) == 0 {
return "", errors.New("hasResponse field not found in prompt data or empty list")
response, err := ExtractResponse(selectedPrompt)
if err != nil {
return ResponseData{}, err
}

firstResponseMap, ok := hasResponse[0].(map[string]interface{})
if !ok {
return "", errors.New("unexpected response format: response data is not a map")
id, err := ExtractID(selectedPrompt)
if err != nil {
return ResponseData{}, err
}

response, ok := firstResponseMap["response"].(string)
if !ok {
return "", errors.New("response field not found in response data or not a string")
responseData := ResponseData{
ID: id,
Response: response,
}

jsonData, err := json.Marshal(response)
if err != nil {
fmt.Println("Error:", err)
return "", err
}
return responseData, nil
}

log.Printf("Selected Response: %v\n", response)
return ResponseData{}, errors.New("no prompt found")

// Add a newline character to the end of the string
jsonDataWithNewline := append(jsonData, '\n')
}

return string(jsonDataWithNewline), nil
func RetrieveRandomResponse(code string) (ResponseData, error) {

responses, err := RetrieveResponsesRankDesc(code)
if err != nil {
return ResponseData{}, err
}

getPrompt, ok := responses.Data["Get"].(map[string]interface{})
if !ok {
return ResponseData{}, errors.New("unexpected response format: 'Get' field not found or not a map")
}

return "", errors.New("no prompt found")
promptData, ok := getPrompt["Prompt"].([]interface{})
if !ok || len(promptData) == 0 {
return ResponseData{}, errors.New("unexpected response format: 'Prompt' field not found or empty list")
}

source := rand.NewSource(time.Now().UnixNano())
rng := rand.New(source)
randomIndex := rng.Intn(len(promptData))
selectedPrompt := promptData[randomIndex]

selectedPromptMap, ok := selectedPrompt.(map[string]interface{})
if !ok {
return ResponseData{}, errors.New("unexpected response format: selected prompt data is not a map")
}

response, err := ExtractResponse(selectedPromptMap)
if err != nil {
return ResponseData{}, err
}

id, err := ExtractID(selectedPromptMap)
if err != nil {
return ResponseData{}, err
}

responseData := ResponseData{
ID: id,
Response: response,
}

return responseData, nil
}

func RetrieveRandomResponse(code string) (ResponseData, error) {
func RetrieveResponsesRankDesc(code string) (*models.GraphQLResponse, error) {

client, err := loadClient()
if err != nil {
return ResponseData{}, err
return nil, err
}

fields := []graphql.Field{
Expand All @@ -269,70 +269,51 @@ func RetrieveRandomResponse(code string) (ResponseData, error) {
WithWhere(where).
Do(ctx)
if err != nil {
panic(err)
return nil, err
}

getPrompt, ok := result.Data["Get"].(map[string]interface{})
return result, nil
}

func ExtractID(selectedPrompt map[string]interface{}) (string, error) {
hasAdditionalInterface, ok := selectedPrompt["_additional"]
if !ok {
return ResponseData{}, errors.New("unexpected response format: 'Get' field not found or not a map")
return "", errors.New("_additional field not found in prompt data")
}

promptData, ok := getPrompt["Prompt"].([]interface{})
if !ok || len(promptData) == 0 {
return ResponseData{}, errors.New("unexpected response format: 'Prompt' field not found or empty list")
additionalMap, ok := hasAdditionalInterface.(map[string]interface{})
if !ok {
return "", errors.New("_additional field is not a map in prompt data")
}

source := rand.NewSource(time.Now().UnixNano())
rng := rand.New(source)
randomIndex := rng.Intn(len(promptData))
selectedPrompt := promptData[randomIndex]
idInterface, ok := additionalMap["id"]
if !ok {
return "", errors.New("id field not found in _additional data")
}

selectedPromptMap, ok := selectedPrompt.(map[string]interface{})
id, ok := idInterface.(string)
if !ok {
return ResponseData{}, errors.New("unexpected response format: selected prompt data is not a map")
return "", errors.New("id field is not a string in _additional data")
}

log.Printf("selectedPromptMap: %v\n", selectedPromptMap)
return id, nil
}

func ExtractResponse(selectedPromptMap map[string]interface{}) (string, error) {
hasResponse, ok := selectedPromptMap["hasResponse"].([]interface{})
if !ok || len(hasResponse) == 0 {
return ResponseData{}, errors.New("hasResponse field not found in prompt data or empty list")
return "", errors.New("hasResponse field not found in prompt data or empty list")
}

firstResponseMap, ok := hasResponse[0].(map[string]interface{})
if !ok {
return ResponseData{}, errors.New("unexpected response format: response data is not a map")
return "", errors.New("unexpected response format: response data is not a map")
}

response, ok := firstResponseMap["response"].(string)
if !ok {
return ResponseData{}, errors.New("response field not found in response data or not a string")
}

hasAdditionalInterface, ok := selectedPromptMap["_additional"]
if !ok {
return ResponseData{}, errors.New("_additional field not found in prompt data")
}

additionalMap, ok := hasAdditionalInterface.(map[string]interface{})
if !ok {
return ResponseData{}, errors.New("_additional field is not a map in prompt data")
return "", errors.New("response field not found in response data or not a string")
}

idInterface, ok := additionalMap["id"]
if !ok {
return ResponseData{}, errors.New("id field not found in _additional data")
}

id, ok := idInterface.(string)
if !ok {
return ResponseData{}, errors.New("id field is not a string in _additional data")
}

responseData := ResponseData{
ID: id,
Response: response,
}

return responseData, nil
return response, nil
}

0 comments on commit 8cc8a6d

Please sign in to comment.