Skip to content

Commit

Permalink
Remove deprecated RAG components and update Go module dependencies. T…
Browse files Browse the repository at this point in the history
…his commit deletes the `rag/types.go`, `rag/openai/vectorizer.go`, and `rag/qdrant` files, which were no longer in use, streamlining the codebase. Additionally, the Go module has been updated to version 1.22.2, ensuring compatibility with the latest features and improvements.
  • Loading branch information
trheyi committed Jan 2, 2025
1 parent edb1237 commit 1ea1b1b
Show file tree
Hide file tree
Showing 10 changed files with 172 additions and 88 deletions.
2 changes: 0 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ module github.com/yaoapp/gou

go 1.22.2

toolchain go1.23.4

require (
github.com/blang/semver/v4 v4.0.0 // indirect
github.com/fatih/color v1.16.0
Expand Down
8 changes: 4 additions & 4 deletions rag/openai/vectorizer.go → rag/driver/openai/vectorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"time"
)

// Vectorizer implements rag.Vectorizer using OpenAI's embeddings API
// Vectorizer implements driver.Vectorizer using OpenAI's embeddings API
type Vectorizer struct {
apiKey string
model string
Expand Down Expand Up @@ -59,7 +59,7 @@ func New(config Config) (*Vectorizer, error) {
}, nil
}

// Vectorize implements rag.Vectorizer
// Vectorize implements driver.Vectorizer
func (v *Vectorizer) Vectorize(ctx context.Context, text string) ([]float32, error) {
// Prepare request body
reqBody := map[string]interface{}{
Expand Down Expand Up @@ -105,7 +105,7 @@ func (v *Vectorizer) Vectorize(ctx context.Context, text string) ([]float32, err
return embedResp.Data[0].Embedding, nil
}

// VectorizeBatch implements rag.Vectorizer
// VectorizeBatch implements driver.Vectorizer
func (v *Vectorizer) VectorizeBatch(ctx context.Context, texts []string) ([][]float32, error) {
// Prepare request body
reqBody := map[string]interface{}{
Expand Down Expand Up @@ -157,7 +157,7 @@ func (v *Vectorizer) VectorizeBatch(ctx context.Context, texts []string) ([][]fl
return embeddings, nil
}

// Close implements rag.Vectorizer
// Close implements driver.Vectorizer
func (v *Vectorizer) Close() error {
return nil
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package openai
package openai_test

import (
"context"
"os"
"testing"

"github.com/yaoapp/gou/rag/driver/openai"
)

func TestOpenAIVectorizer(t *testing.T) {
Expand All @@ -13,7 +15,7 @@ func TestOpenAIVectorizer(t *testing.T) {
}

// Create vectorizer
vectorizer, err := New(Config{
vectorizer, err := openai.New(openai.Config{
APIKey: apiKey,
Model: "text-embedding-ada-002",
})
Expand Down
File renamed without changes.
45 changes: 21 additions & 24 deletions rag/qdrant/engine.go → rag/driver/qdrant/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ import (
"time"

"github.com/qdrant/go-client/qdrant"
"github.com/yaoapp/gou/rag"
"github.com/yaoapp/gou/rag/driver"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
)

// Engine implements the rag.Engine interface using Qdrant as the vector store backend
// Engine implements the driver.Engine interface using Qdrant as the vector store backend
type Engine struct {
client *qdrant.Client
vectorizer rag.Vectorizer
vectorizer driver.Vectorizer
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
Expand All @@ -33,7 +33,7 @@ type Config struct {
Host string // Host address of Qdrant server, e.g., "localhost"
Port uint32 // Port number of Qdrant server, default is 6334 for gRPC
APIKey string // Optional API key for authentication
Vectorizer rag.Vectorizer
Vectorizer driver.Vectorizer
}

// NewEngine creates a new instance of the Qdrant engine with the given configuration
Expand Down Expand Up @@ -74,7 +74,7 @@ func NewEngine(config Config) (*Engine, error) {
}

// CreateIndex creates a new vector collection in Qdrant with the given configuration
func (e *Engine) CreateIndex(ctx context.Context, config rag.IndexConfig) error {
func (e *Engine) CreateIndex(ctx context.Context, config driver.IndexConfig) error {
// Get vector dimension from vectorizer
dims, err := e.getVectorDimension()
if err != nil {
Expand Down Expand Up @@ -116,7 +116,7 @@ func (e *Engine) ListIndexes(ctx context.Context) ([]string, error) {
}

// IndexDoc adds or updates a document in the specified vector collection
func (e *Engine) IndexDoc(ctx context.Context, indexName string, doc *rag.Document) error {
func (e *Engine) IndexDoc(ctx context.Context, indexName string, doc *driver.Document) error {
if err := e.checkContext(ctx); err != nil {
return err
}
Expand Down Expand Up @@ -168,7 +168,7 @@ func (e *Engine) IndexDoc(ctx context.Context, indexName string, doc *rag.Docume
}

// Search performs a vector similarity search in the specified collection
func (e *Engine) Search(ctx context.Context, indexName string, vector []float32, opts rag.VectorSearchOptions) ([]rag.SearchResult, error) {
func (e *Engine) Search(ctx context.Context, indexName string, vector []float32, opts driver.VectorSearchOptions) ([]driver.SearchResult, error) {
if err := e.checkContext(ctx); err != nil {
return nil, err
}
Expand Down Expand Up @@ -202,7 +202,7 @@ func (e *Engine) Search(ctx context.Context, indexName string, vector []float32,
return nil, fmt.Errorf("failed to search: %w", err)
}

results := make([]rag.SearchResult, len(points))
results := make([]driver.SearchResult, len(points))
for i, point := range points {
content := point.Payload["content"].GetStringValue()
var metadata map[string]interface{}
Expand All @@ -212,7 +212,7 @@ func (e *Engine) Search(ctx context.Context, indexName string, vector []float32,
}
}

results[i] = rag.SearchResult{
results[i] = driver.SearchResult{
DocID: point.Id.GetUuid(),
Score: float64(point.Score),
Content: content,
Expand All @@ -223,7 +223,7 @@ func (e *Engine) Search(ctx context.Context, indexName string, vector []float32,
}

// GetDocument retrieves a document by its ID from the specified collection
func (e *Engine) GetDocument(ctx context.Context, indexName string, DocID string) (*rag.Document, error) {
func (e *Engine) GetDocument(ctx context.Context, indexName string, DocID string) (*driver.Document, error) {
points, err := e.client.Get(ctx, &qdrant.GetPoints{
CollectionName: indexName,
Ids: []*qdrant.PointId{qdrant.NewID(DocID)},
Expand All @@ -250,7 +250,7 @@ func (e *Engine) GetDocument(ctx context.Context, indexName string, DocID string
}
}

return &rag.Document{
return &driver.Document{
DocID: DocID,
Content: content,
Metadata: metadata,
Expand Down Expand Up @@ -307,14 +307,11 @@ func (e *Engine) Close() error {

// Return combined errors if any
if len(errs) > 0 {
var errMsg string
errMsgs := make([]string, len(errs))
for i, err := range errs {
if i > 0 {
errMsg += "; "
}
errMsg += err.Error()
errMsgs[i] = err.Error()
}
return fmt.Errorf(errMsg)
return fmt.Errorf("multiple errors: %s", strings.Join(errMsgs, "; "))
}

return nil
Expand Down Expand Up @@ -352,7 +349,7 @@ func convertStructToMap(s *qdrant.Struct) map[string]interface{} {
}

// IndexBatch adds or updates multiple documents in batch
func (e *Engine) IndexBatch(ctx context.Context, indexName string, docs []*rag.Document) (string, error) {
func (e *Engine) IndexBatch(ctx context.Context, indexName string, docs []*driver.Document) (string, error) {
if err := e.checkContext(ctx); err != nil {
return "", err
}
Expand Down Expand Up @@ -474,7 +471,7 @@ func (e *Engine) DeleteBatch(ctx context.Context, indexName string, DocIDs []str
}

// GetTaskInfo retrieves information about an asynchronous task
func (e *Engine) GetTaskInfo(ctx context.Context, taskID string) (*rag.TaskInfo, error) {
func (e *Engine) GetTaskInfo(ctx context.Context, taskID string) (*driver.TaskInfo, error) {
// Parse the task type and timestamp from the taskID
parts := strings.Split(taskID, "-")
if len(parts) != 3 || (parts[0] != "batch" && parts[1] != "index" && parts[1] != "delete") {
Expand All @@ -488,12 +485,12 @@ func (e *Engine) GetTaskInfo(ctx context.Context, taskID string) (*rag.TaskInfo,

// For now, we'll consider tasks as completed after 5 seconds
elapsed := time.Since(time.Unix(0, timestamp))
status := rag.StatusRunning
status := driver.StatusRunning
if elapsed > 5*time.Second {
status = rag.StatusComplete
status = driver.StatusComplete
}

return &rag.TaskInfo{
return &driver.TaskInfo{
TaskID: taskID,
Status: status,
Created: timestamp,
Expand All @@ -505,10 +502,10 @@ func (e *Engine) GetTaskInfo(ctx context.Context, taskID string) (*rag.TaskInfo,
}

// ListTasks returns a list of all tasks for the specified collection
func (e *Engine) ListTasks(ctx context.Context, indexName string) ([]*rag.TaskInfo, error) {
func (e *Engine) ListTasks(ctx context.Context, indexName string) ([]*driver.TaskInfo, error) {
// Since Qdrant doesn't provide direct task listing,
// we'll return an empty list for now
return []*rag.TaskInfo{}, nil
return []*driver.TaskInfo{}, nil
}

// CancelTask cancels an ongoing asynchronous task
Expand Down
Loading

0 comments on commit 1ea1b1b

Please sign in to comment.