Skip to content

Commit

Permalink
change retrieval function to respond with ID in order to allow voting
Browse files Browse the repository at this point in the history
  • Loading branch information
ili16 committed Feb 27, 2024
1 parent 8a74389 commit 392e0bf
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 7 deletions.
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ func main() {

log.Printf("Decoded Query: %s", decodedQuery)

response, err := weaviate.RetrieveResponse(decodedQuery)
response, err := weaviate.RetrieveRandomResponse(decodedQuery)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

c.Data(http.StatusOK, "text/plain; charset=utf-8", []byte(response))
c.JSON(http.StatusOK, response)
})

router.POST("/generate", func(c *gin.Context) {
Expand Down
5 changes: 5 additions & 0 deletions weaviate/weaviate.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ func InitSchema() error {
DataType: []string{"int"},
Description: "The relative rank for this response against other ones regarding the same code",
Name: "rank",
ModuleConfig: map[string]interface{}{
"text2vec-transformers": map[string]interface{}{
"skip": true,
},
},
},
},
}
Expand Down
112 changes: 107 additions & 5 deletions weaviate/weaviate_retrieval.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ type PromptProperties struct {
Rank int `json:"rank"`
}

type ResponseData struct {
ID string `json:"id"`
Response string `json:"response"`
}

func RetrieveProperties(id string) (PromptProperties, error) {

client, err := loadClient()
Expand Down Expand Up @@ -115,7 +120,7 @@ func RetrievePromptCount(code string) (int, error) {
return int(countFloat), nil
}

func RetrieveResponse(code string) (string, error) {
func RetrieveBestResponse(code string) (string, error) {

client, err := loadClient()
if err != nil {
Expand Down Expand Up @@ -152,8 +157,6 @@ func RetrieveResponse(code string) (string, error) {
panic(err)
}

log.Printf("result= %v\n", result)

getPrompt, ok := result.Data["Get"].(map[string]interface{})
if !ok {
return "", errors.New("unexpected response format: 'Get' field not found or not a map")
Expand Down Expand Up @@ -198,8 +201,10 @@ func RetrieveResponse(code string) (string, error) {

// If there are prompts with the same highest rank, select one randomly
if len(highestRankPrompts) > 0 {
rand.Seed(time.Now().UnixNano())
randomIndex := rand.Intn(len(highestRankPrompts))

source := rand.NewSource(time.Now().UnixNano())
rng := rand.New(source)
randomIndex := rng.Intn(len(highestRankPrompts))
selectedPrompt := highestRankPrompts[randomIndex]

hasResponse, ok := selectedPrompt["hasResponse"].([]interface{})
Expand Down Expand Up @@ -234,3 +239,100 @@ func RetrieveResponse(code string) (string, error) {
return "", errors.New("no prompt found")

}

func RetrieveRandomResponse(code string) (ResponseData, error) {
client, err := loadClient()
if err != nil {
return ResponseData{}, err
}

fields := []graphql.Field{
{Name: "hasResponse", Fields: []graphql.Field{
{Name: "... on Response", Fields: []graphql.Field{
{Name: "response"},
}},
}},
{Name: "_additional", Fields: []graphql.Field{
{Name: "id"},
}},
}

where := filters.Where().
WithPath([]string{"code"}).
WithOperator(filters.Like).
WithValueText(code)

ctx := context.Background()
result, err := client.GraphQL().Get().
WithClassName("Prompt").
WithFields(fields...).
WithWhere(where).
Do(ctx)
if err != nil {
panic(err)
}

getPrompt, ok := result.Data["Get"].(map[string]interface{})
if !ok {
return ResponseData{}, errors.New("unexpected response format: 'Get' field not found or not a map")
}

promptData, ok := getPrompt["Prompt"].([]interface{})
if !ok || len(promptData) == 0 {
return ResponseData{}, errors.New("unexpected response format: 'Prompt' field not found or empty list")
}

source := rand.NewSource(time.Now().UnixNano())
rng := rand.New(source)
randomIndex := rng.Intn(len(promptData))
selectedPrompt := promptData[randomIndex]

selectedPromptMap, ok := selectedPrompt.(map[string]interface{})
if !ok {
return ResponseData{}, errors.New("unexpected response format: selected prompt data is not a map")
}

log.Printf("selectedPromptMap: %v\n", selectedPromptMap)

hasResponse, ok := selectedPromptMap["hasResponse"].([]interface{})
if !ok || len(hasResponse) == 0 {
return ResponseData{}, errors.New("hasResponse field not found in prompt data or empty list")
}

firstResponseMap, ok := hasResponse[0].(map[string]interface{})
if !ok {
return ResponseData{}, errors.New("unexpected response format: response data is not a map")
}

response, ok := firstResponseMap["response"].(string)
if !ok {
return ResponseData{}, errors.New("response field not found in response data or not a string")
}

hasAdditionalInterface, ok := selectedPromptMap["_additional"]
if !ok {
return ResponseData{}, errors.New("_additional field not found in prompt data")
}

additionalMap, ok := hasAdditionalInterface.(map[string]interface{})
if !ok {
return ResponseData{}, errors.New("_additional field is not a map in prompt data")
}

idInterface, ok := additionalMap["id"]
if !ok {
return ResponseData{}, errors.New("id field not found in _additional data")
}

id, ok := idInterface.(string)
if !ok {
return ResponseData{}, errors.New("id field is not a string in _additional data")
}

responseData := ResponseData{
ID: id,
Response: response,
}

return responseData, nil
}

0 comments on commit 392e0bf

Please sign in to comment.