diff --git a/main.go b/main.go index 6ac316d..838f6a5 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "log" "net/http" "net/url" + "os" "strconv" "github.com/gin-gonic/gin" @@ -82,6 +83,7 @@ func main() { router.GET("/weaviate/retrieveresponselist", func(c *gin.Context) { searchQuery := c.Query("query") + instructType := c.Query("instructType") decodedQuery, err := url.QueryUnescape(searchQuery) if err != nil { @@ -91,7 +93,7 @@ func main() { log.Printf("Decoded Query: %s", decodedQuery) - responseList, err := weaviate.ResponseList(decodedQuery) + responseList, err := weaviate.ResponseList(decodedQuery, instructType) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -130,6 +132,24 @@ func main() { c.JSON(http.StatusOK, response) }) + router.GET("/get-similar-meaning", func(c *gin.Context) { + searchQuery := c.Query("meaning") + + decodedQuery, err := url.QueryUnescape(searchQuery) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid search query"}) + return + } + + response, err := SemanticSimilarityByMeaning(decodedQuery) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, response) + }) + router.GET("/get-similar-code", func(c *gin.Context) { searchQuery := c.Query("code") @@ -139,7 +159,25 @@ func main() { return } - response, err := SemanticSimilarity(decodedQuery) + response, err := SemanticSimilarityByCode(decodedQuery) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, response) + }) + + router.GET("/get-instructtype", func(c *gin.Context) { + searchQuery := c.Query("code") + + decodedQuery, err := url.QueryUnescape(searchQuery) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid search query"}) + return + } + + response, err := weaviate.GetInstructTypes(decodedQuery) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -217,6 +255,16 @@ func main() { router.POST("/add-instruct", redis.AddInstruct) router.POST("/del-instruct", redis.DeleteInstruct) router.GET("/get-all-sets", redis.GetAllSets) + router.GET("/delete-db", func(c *gin.Context) { + secretkey := c.Query("key") + + if secretkey == os.Getenv("delete_key") { + ResetDB() + c.JSON(http.StatusOK, "OK") + } else { + c.JSON(http.StatusUnauthorized, "Unauthorized") + } + }) err = router.Run(":8080") if err != nil { @@ -224,7 +272,7 @@ func main() { } } -func SemanticSimilarity(code string) ([]string, error) { +func SemanticSimilarityByCode(code string) ([]string, error) { PromptExists, exists := weaviate.RetrieveHasSemanticMeaning(code) if !exists { SemanticMeaning := ollama.SemanticMeaning("", code, false) @@ -242,5 +290,22 @@ func SemanticSimilarity(code string) ([]string, error) { return similarCode, err } +} + +func SemanticSimilarityByMeaning(meaning string) ([]string, error) { + similarCode, err := weaviate.GetSimilarSemanticMeaning(meaning) + if err != nil { + return nil, err + } + return similarCode, err +} +func ResetDB() { + redis.DeleteAllSets() + redis.InitRedis() + weaviate.DeleteAllClasses() + err := weaviate.InitSchema() + if err != nil { + return + } } diff --git a/ollama/ollama.go b/ollama/ollama.go index 74e5ae1..1cd2add 100644 --- a/ollama/ollama.go +++ b/ollama/ollama.go @@ -111,7 +111,7 @@ func GenerateResponse(prompt map[string]interface{}) (weaviate.ResponseData, err return weaviate.ResponseData{}, errors.New("invalid response format") } - PromptID, err := weaviate.CreatePromptObject(instruct, code, "Prompt", gitURL) + PromptID, err := weaviate.CreatePromptObject(instruct, set, code, "Prompt", gitURL) if err != nil { return weaviate.ResponseData{}, err } diff --git a/redis/redis.go b/redis/redis.go index 913658b..b8ec388 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -265,3 +265,27 @@ func GetAllSets(c *gin.Context) { return } + +func DeleteAllSets() { + rdb := loadClient() + + keysCmd := rdb.Keys(context.Background(), "*") // Get all keys matching the pattern "*" + + keys, err := keysCmd.Result() + if err != nil { + return + } + + for _, key := range keys { + typeCmd := rdb.Type(context.Background(), key) // Get the type of the key + + keyType, err := typeCmd.Result() + if err != nil { + return + } + + if keyType == "set" { + rdb.Del(context.Background(), key) + } + } +} diff --git a/weaviate/weaviate.go b/weaviate/weaviate.go index 54924f5..de849e2 100644 --- a/weaviate/weaviate.go +++ b/weaviate/weaviate.go @@ -129,6 +129,16 @@ func InitSchema() error { }, }, }, + { + DataType: []string{"text"}, + Description: "instruct type", + Name: "instructType", + ModuleConfig: map[string]interface{}{ + "text2vec-transformers": map[string]interface{}{ + "skip": true, + }, + }, + }, { DataType: []string{"text"}, Description: "The code which is targeted in the prompt", @@ -226,17 +236,18 @@ func createClass(className, description, vectorizer string, properties []*models return nil } -func CreatePromptObject(instruct string, code string, class string, gitURL string) (string, error) { +func CreatePromptObject(instruct string, instructType string, code string, class string, gitURL string) (string, error) { client, err := loadClient() if err != nil { return "", err } dataSchema := map[string]interface{}{ - "instruct": instruct, - "code": code, - "rank": 1, - "gitURL": gitURL, + "instruct": instruct, + "code": code, + "rank": 1, + "gitURL": gitURL, + "instructType": instructType, } weaviateObject, err := client.Data().Creator(). @@ -438,3 +449,18 @@ func CreateReferenceSemanticMeaningToPrompt(semanticMeaningID string, PromptID s return nil } + +func DeleteAllClasses() { + client, err := loadClient() + if err != nil { + log.Printf("Error loading client: %v\n", err) + return + } + + classes := []string{"Prompt", "Response", "SemanticMeaning"} + + for _, ch := range classes { + err = client.Schema().ClassDeleter().WithClassName(ch).Do(context.Background()) + + } +} diff --git a/weaviate/weaviate_retrieval.go b/weaviate/weaviate_retrieval.go index 9aed2e0..0ece8fe 100644 --- a/weaviate/weaviate_retrieval.go +++ b/weaviate/weaviate_retrieval.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "math/rand" + "strings" "time" "github.com/weaviate/weaviate-go-client/v4/weaviate/filters" @@ -49,24 +50,34 @@ func RetrieveProperties(id string) (PromptProperties, error) { return PromptProperties{}, err } - var responseText string - if len(temp.HasResponse) > 0 { - responseTextBytes, err := json.Marshal(temp.HasResponse) - if err != nil { - return PromptProperties{}, err - } - responseText = string(responseTextBytes) + responseID, err := extractUUIDFromHasResponse(temp.HasResponse) + if err != nil { + return PromptProperties{}, err } - response, err := RetrieveResponseByID(id) + objects, err = client.Data().ObjectsGetter(). + WithID(responseID). + WithClassName("Response"). + Do(context.Background()) if err != nil { return PromptProperties{}, err } - responseText, ok := response.(string) - if !ok { - return PromptProperties{}, fmt.Errorf("response from RetrieveResponseByID is not a string") + + propertiesJSON, err = json.Marshal(objects[0].Properties) + if err != nil { + return PromptProperties{}, err } + var responseTemp struct { + Response string `json:"response"` + } + + if err := json.Unmarshal(propertiesJSON, &responseTemp); err != nil { + return PromptProperties{}, err + } + + responseText := responseTemp.Response + promptProperties := PromptProperties{ Code: temp.Code, HasResponse: responseText, @@ -78,6 +89,17 @@ func RetrieveProperties(id string) (PromptProperties, error) { return promptProperties, nil } +func extractUUIDFromHasResponse(hasResponse []map[string]interface{}) (string, error) { + for _, response := range hasResponse { + if beacon, ok := response["beacon"].(string); ok { + // Split the beacon string to extract the UUID + uuid := beacon[strings.LastIndex(beacon, "/")+1:] + return uuid, nil + } + } + return "", fmt.Errorf("no UUID found in hasResponse field") +} + func RetrievePromptCount(code string) (int, error) { client, err := loadClient() if err != nil { @@ -184,8 +206,8 @@ func RetrieveResponseByID(id string) (interface{}, error) { return response, nil } -func ResponseList(code string) ([]string, error) { - responses, err := RetrieveResponsesRankDesc(code) +func ResponseList(code string, instructtype string) ([]string, error) { + responses, err := RetrieveResponsesRankDesc(code, instructtype) if err != nil { return nil, err } @@ -225,7 +247,7 @@ func ResponseList(code string) ([]string, error) { func RetrieveBestResponse(code string) (ResponseData, error) { - responses, err := RetrieveResponsesRankDesc(code) + responses, err := RetrieveResponsesRankDesc(code, "") if err != nil { return ResponseData{}, err } @@ -307,7 +329,7 @@ func RetrieveBestResponse(code string) (ResponseData, error) { func RetrieveRandomResponse(code string) (ResponseData, error) { - responses, err := RetrieveResponsesRankDesc(code) + responses, err := RetrieveResponsesRankDesc(code, "") if err != nil { return ResponseData{}, err } @@ -356,7 +378,7 @@ func RetrieveRandomResponse(code string) (ResponseData, error) { return responseData, nil } -func RetrieveResponsesRankDesc(code string) (*models.GraphQLResponse, error) { +func RetrieveResponsesRankDesc(code string, instructType string) (*models.GraphQLResponse, error) { client, err := loadClient() if err != nil { @@ -376,10 +398,22 @@ func RetrieveResponsesRankDesc(code string) (*models.GraphQLResponse, error) { {Name: "instruct"}, } + if instructType == "" { + instructType = "*" // wildcard to match all instruct types + } + where := filters.Where(). - WithPath([]string{"code"}). - WithOperator(filters.Like). - WithValueText(code) + WithOperator(filters.And). + WithOperands([]*filters.WhereBuilder{ + filters.Where(). + WithPath([]string{"code"}). + WithOperator(filters.Like). + WithValueText(code), + filters.Where(). + WithPath([]string{"instructType"}). + WithOperator(filters.Equal). + WithValueText(instructType), + }) rankDesc := graphql.Sort{ Path: []string{"rank"}, Order: graphql.Desc, @@ -533,7 +567,7 @@ func RetrieveHasSemanticMeaning(code string) (string, bool) { return semanticMeaning, true } -func GetSimilarSemanticMeaning(code string) ([]string, error) { +func GetSimilarSemanticMeaning(meaning string) ([]string, error) { client, err := loadClient() if err != nil { log.Printf("Error loading client: %v", err) @@ -551,7 +585,7 @@ func GetSimilarSemanticMeaning(code string) ([]string, error) { } withNearText := client.GraphQL().NearTextArgBuilder(). - WithConcepts([]string{code}). + WithConcepts([]string{meaning}). WithCertainty(0.8) result, err := client.GraphQL().Get(). @@ -598,3 +632,66 @@ func GetSimilarSemanticMeaning(code string) ([]string, error) { return gitURLs, nil } + +func GetInstructTypes(code string) ([]string, error) { + client, err := loadClient() + if err != nil { + log.Printf("Error loading client: %v", err) + } + + fields := []graphql.Field{ + {Name: "instructType"}, + } + + where := filters.Where(). + WithPath([]string{"code"}). + WithOperator(filters.Like). + WithValueText(code) + + result, err := client.GraphQL().Get(). + WithClassName("Prompt"). + WithFields(fields...). + WithWhere(where). + Do(context.Background()) + if err != nil { + return nil, err + } + + uniqueExplanationStrings := ExtractExplanationStrings(result) + + log.Printf("uniqueExplanationStrings: %v", uniqueExplanationStrings) + return uniqueExplanationStrings, nil +} + +func ExtractExplanationStrings(result *models.GraphQLResponse) []string { + var explanationStrings []string + + getMap, ok := result.Data["Get"].(map[string]interface{}) + if !ok { + return explanationStrings + } + + promptList, ok := getMap["Prompt"].([]interface{}) + if !ok { + return explanationStrings + } + + explanationSet := make(map[string]struct{}) + for _, prompt := range promptList { + promptMap, ok := prompt.(map[string]interface{}) + if !ok { + continue + } + explanation, ok := promptMap["instructType"].(string) + if !ok { + continue + } + explanationSet[explanation] = struct{}{} + } + + for explanation := range explanationSet { + explanationStrings = append(explanationStrings, explanation) + } + + return explanationStrings +}