diff --git a/go.mod b/go.mod index 763d61d..8e9c735 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,6 @@ module github.com/yaoapp/gou go 1.22.2 -toolchain go1.23.4 - require ( github.com/blang/semver/v4 v4.0.0 // indirect github.com/fatih/color v1.16.0 diff --git a/rag/openai/vectorizer.go b/rag/driver/openai/vectorizer.go similarity index 95% rename from rag/openai/vectorizer.go rename to rag/driver/openai/vectorizer.go index 41ff4dd..2e0e4a3 100644 --- a/rag/openai/vectorizer.go +++ b/rag/driver/openai/vectorizer.go @@ -10,7 +10,7 @@ import ( "time" ) -// Vectorizer implements rag.Vectorizer using OpenAI's embeddings API +// Vectorizer implements driver.Vectorizer using OpenAI's embeddings API type Vectorizer struct { apiKey string model string @@ -59,7 +59,7 @@ func New(config Config) (*Vectorizer, error) { }, nil } -// Vectorize implements rag.Vectorizer +// Vectorize implements driver.Vectorizer func (v *Vectorizer) Vectorize(ctx context.Context, text string) ([]float32, error) { // Prepare request body reqBody := map[string]interface{}{ @@ -105,7 +105,7 @@ func (v *Vectorizer) Vectorize(ctx context.Context, text string) ([]float32, err return embedResp.Data[0].Embedding, nil } -// VectorizeBatch implements rag.Vectorizer +// VectorizeBatch implements driver.Vectorizer func (v *Vectorizer) VectorizeBatch(ctx context.Context, texts []string) ([][]float32, error) { // Prepare request body reqBody := map[string]interface{}{ @@ -157,7 +157,7 @@ func (v *Vectorizer) VectorizeBatch(ctx context.Context, texts []string) ([][]fl return embeddings, nil } -// Close implements rag.Vectorizer +// Close implements driver.Vectorizer func (v *Vectorizer) Close() error { return nil } diff --git a/rag/openai/vectorizer_test.go b/rag/driver/openai/vectorizer_test.go similarity index 96% rename from rag/openai/vectorizer_test.go rename to rag/driver/openai/vectorizer_test.go index b8bd672..9ce2298 100644 --- a/rag/openai/vectorizer_test.go +++ b/rag/driver/openai/vectorizer_test.go @@ -1,9 +1,11 @@ -package openai +package openai_test import ( "context" "os" "testing" + + "github.com/yaoapp/gou/rag/driver/openai" ) func TestOpenAIVectorizer(t *testing.T) { @@ -13,7 +15,7 @@ func TestOpenAIVectorizer(t *testing.T) { } // Create vectorizer - vectorizer, err := New(Config{ + vectorizer, err := openai.New(openai.Config{ APIKey: apiKey, Model: "text-embedding-ada-002", }) diff --git a/rag/qdrant/README.md b/rag/driver/qdrant/README.md similarity index 100% rename from rag/qdrant/README.md rename to rag/driver/qdrant/README.md diff --git a/rag/qdrant/engine.go b/rag/driver/qdrant/engine.go similarity index 93% rename from rag/qdrant/engine.go rename to rag/driver/qdrant/engine.go index 66cda97..ca2ac7c 100644 --- a/rag/qdrant/engine.go +++ b/rag/driver/qdrant/engine.go @@ -9,16 +9,16 @@ import ( "time" "github.com/qdrant/go-client/qdrant" - "github.com/yaoapp/gou/rag" + "github.com/yaoapp/gou/rag/driver" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" ) -// Engine implements the rag.Engine interface using Qdrant as the vector store backend +// Engine implements the driver.Engine interface using Qdrant as the vector store backend type Engine struct { client *qdrant.Client - vectorizer rag.Vectorizer + vectorizer driver.Vectorizer ctx context.Context cancel context.CancelFunc wg sync.WaitGroup @@ -33,7 +33,7 @@ type Config struct { Host string // Host address of Qdrant server, e.g., "localhost" Port uint32 // Port number of Qdrant server, default is 6334 for gRPC APIKey string // Optional API key for authentication - Vectorizer rag.Vectorizer + Vectorizer driver.Vectorizer } // NewEngine creates a new instance of the Qdrant engine with the given configuration @@ -74,7 +74,7 @@ func NewEngine(config Config) (*Engine, error) { } // CreateIndex creates a new vector collection in Qdrant with the given configuration -func (e *Engine) CreateIndex(ctx context.Context, config rag.IndexConfig) error { +func (e *Engine) CreateIndex(ctx context.Context, config driver.IndexConfig) error { // Get vector dimension from vectorizer dims, err := e.getVectorDimension() if err != nil { @@ -116,7 +116,7 @@ func (e *Engine) ListIndexes(ctx context.Context) ([]string, error) { } // IndexDoc adds or updates a document in the specified vector collection -func (e *Engine) IndexDoc(ctx context.Context, indexName string, doc *rag.Document) error { +func (e *Engine) IndexDoc(ctx context.Context, indexName string, doc *driver.Document) error { if err := e.checkContext(ctx); err != nil { return err } @@ -168,7 +168,7 @@ func (e *Engine) IndexDoc(ctx context.Context, indexName string, doc *rag.Docume } // Search performs a vector similarity search in the specified collection -func (e *Engine) Search(ctx context.Context, indexName string, vector []float32, opts rag.VectorSearchOptions) ([]rag.SearchResult, error) { +func (e *Engine) Search(ctx context.Context, indexName string, vector []float32, opts driver.VectorSearchOptions) ([]driver.SearchResult, error) { if err := e.checkContext(ctx); err != nil { return nil, err } @@ -202,7 +202,7 @@ func (e *Engine) Search(ctx context.Context, indexName string, vector []float32, return nil, fmt.Errorf("failed to search: %w", err) } - results := make([]rag.SearchResult, len(points)) + results := make([]driver.SearchResult, len(points)) for i, point := range points { content := point.Payload["content"].GetStringValue() var metadata map[string]interface{} @@ -212,7 +212,7 @@ func (e *Engine) Search(ctx context.Context, indexName string, vector []float32, } } - results[i] = rag.SearchResult{ + results[i] = driver.SearchResult{ DocID: point.Id.GetUuid(), Score: float64(point.Score), Content: content, @@ -223,7 +223,7 @@ func (e *Engine) Search(ctx context.Context, indexName string, vector []float32, } // GetDocument retrieves a document by its ID from the specified collection -func (e *Engine) GetDocument(ctx context.Context, indexName string, DocID string) (*rag.Document, error) { +func (e *Engine) GetDocument(ctx context.Context, indexName string, DocID string) (*driver.Document, error) { points, err := e.client.Get(ctx, &qdrant.GetPoints{ CollectionName: indexName, Ids: []*qdrant.PointId{qdrant.NewID(DocID)}, @@ -250,7 +250,7 @@ func (e *Engine) GetDocument(ctx context.Context, indexName string, DocID string } } - return &rag.Document{ + return &driver.Document{ DocID: DocID, Content: content, Metadata: metadata, @@ -307,14 +307,11 @@ func (e *Engine) Close() error { // Return combined errors if any if len(errs) > 0 { - var errMsg string + errMsgs := make([]string, len(errs)) for i, err := range errs { - if i > 0 { - errMsg += "; " - } - errMsg += err.Error() + errMsgs[i] = err.Error() } - return fmt.Errorf(errMsg) + return fmt.Errorf("multiple errors: %s", strings.Join(errMsgs, "; ")) } return nil @@ -352,7 +349,7 @@ func convertStructToMap(s *qdrant.Struct) map[string]interface{} { } // IndexBatch adds or updates multiple documents in batch -func (e *Engine) IndexBatch(ctx context.Context, indexName string, docs []*rag.Document) (string, error) { +func (e *Engine) IndexBatch(ctx context.Context, indexName string, docs []*driver.Document) (string, error) { if err := e.checkContext(ctx); err != nil { return "", err } @@ -474,7 +471,7 @@ func (e *Engine) DeleteBatch(ctx context.Context, indexName string, DocIDs []str } // GetTaskInfo retrieves information about an asynchronous task -func (e *Engine) GetTaskInfo(ctx context.Context, taskID string) (*rag.TaskInfo, error) { +func (e *Engine) GetTaskInfo(ctx context.Context, taskID string) (*driver.TaskInfo, error) { // Parse the task type and timestamp from the taskID parts := strings.Split(taskID, "-") if len(parts) != 3 || (parts[0] != "batch" && parts[1] != "index" && parts[1] != "delete") { @@ -488,12 +485,12 @@ func (e *Engine) GetTaskInfo(ctx context.Context, taskID string) (*rag.TaskInfo, // For now, we'll consider tasks as completed after 5 seconds elapsed := time.Since(time.Unix(0, timestamp)) - status := rag.StatusRunning + status := driver.StatusRunning if elapsed > 5*time.Second { - status = rag.StatusComplete + status = driver.StatusComplete } - return &rag.TaskInfo{ + return &driver.TaskInfo{ TaskID: taskID, Status: status, Created: timestamp, @@ -505,10 +502,10 @@ func (e *Engine) GetTaskInfo(ctx context.Context, taskID string) (*rag.TaskInfo, } // ListTasks returns a list of all tasks for the specified collection -func (e *Engine) ListTasks(ctx context.Context, indexName string) ([]*rag.TaskInfo, error) { +func (e *Engine) ListTasks(ctx context.Context, indexName string) ([]*driver.TaskInfo, error) { // Since Qdrant doesn't provide direct task listing, // we'll return an empty list for now - return []*rag.TaskInfo{}, nil + return []*driver.TaskInfo{}, nil } // CancelTask cancels an ongoing asynchronous task diff --git a/rag/qdrant/engine_test.go b/rag/driver/qdrant/engine_test.go similarity index 91% rename from rag/qdrant/engine_test.go rename to rag/driver/qdrant/engine_test.go index 1b26b8d..a9f5079 100644 --- a/rag/qdrant/engine_test.go +++ b/rag/driver/qdrant/engine_test.go @@ -11,8 +11,8 @@ import ( "time" "github.com/stretchr/testify/assert" - "github.com/yaoapp/gou/rag" - "github.com/yaoapp/gou/rag/openai" + "github.com/yaoapp/gou/rag/driver" + "github.com/yaoapp/gou/rag/driver/openai" ) func getTestConfig(t *testing.T) Config { @@ -63,7 +63,7 @@ func TestBasicOperations(t *testing.T) { // Test index operations indexName := "test_index" - err = engine.CreateIndex(ctx, rag.IndexConfig{Name: indexName}) + err = engine.CreateIndex(ctx, driver.IndexConfig{Name: indexName}) assert.NoError(t, err) // List indexes @@ -72,7 +72,7 @@ func TestBasicOperations(t *testing.T) { assert.Contains(t, indexes, indexName) // Test document operations - doc := &rag.Document{ + doc := &driver.Document{ DocID: "123e4567-e89b-12d3-a456-426614174000", Content: "This is a test document for Qdrant vector search.", Metadata: map[string]interface{}{"type": "test", "version": 1.0}, @@ -90,7 +90,7 @@ func TestBasicOperations(t *testing.T) { assert.Equal(t, doc.Metadata["version"], retrieved.Metadata["version"]) // Search - searchOpts := rag.VectorSearchOptions{ + searchOpts := driver.VectorSearchOptions{ QueryText: "test document", TopK: 5, } @@ -122,17 +122,17 @@ func TestBatchOperations(t *testing.T) { defer engine.Close() indexName := "test_batch_index" - err = engine.CreateIndex(ctx, rag.IndexConfig{Name: indexName}) + err = engine.CreateIndex(ctx, driver.IndexConfig{Name: indexName}) assert.NoError(t, err) defer engine.DeleteIndex(ctx, indexName) // Prepare batch documents - docs := make([]*rag.Document, 10) + docs := make([]*driver.Document, 10) docIDs := make([]string, 10) for i := 0; i < 10; i++ { // Use proper UUID format docID := fmt.Sprintf("00000000-0000-0000-0000-%012d", i) - docs[i] = &rag.Document{ + docs[i] = &driver.Document{ DocID: docID, Content: fmt.Sprintf("This is test document %d", i), Metadata: map[string]interface{}{"index": i}, @@ -148,18 +148,18 @@ func TestBatchOperations(t *testing.T) { assert.NotEmpty(t, taskID) // Wait for indexing to complete and verify with retries - var taskInfo *rag.TaskInfo + var taskInfo *driver.TaskInfo for i := 0; i < 10; i++ { taskInfo, err = engine.GetTaskInfo(ctx, taskID) if err != nil { t.Fatalf("Failed to get task info: %v", err) } - if taskInfo.Status == rag.StatusComplete { + if taskInfo.Status == driver.StatusComplete { break } time.Sleep(time.Second) } - assert.Equal(t, rag.StatusComplete, taskInfo.Status) + assert.Equal(t, driver.StatusComplete, taskInfo.Status) // Wait a bit more to ensure documents are fully indexed time.Sleep(2 * time.Second) @@ -182,18 +182,18 @@ func TestBatchOperations(t *testing.T) { assert.NotEmpty(t, deleteTaskID) // Wait for deletion to complete and verify with retries - var deleteTaskInfo *rag.TaskInfo + var deleteTaskInfo *driver.TaskInfo for i := 0; i < 10; i++ { deleteTaskInfo, err = engine.GetTaskInfo(ctx, deleteTaskID) if err != nil { t.Fatalf("Failed to get delete task info: %v", err) } - if deleteTaskInfo.Status == rag.StatusComplete { + if deleteTaskInfo.Status == driver.StatusComplete { break } time.Sleep(time.Second) } - assert.Equal(t, rag.StatusComplete, deleteTaskInfo.Status) + assert.Equal(t, driver.StatusComplete, deleteTaskInfo.Status) // Wait a bit more to ensure documents are fully deleted time.Sleep(2 * time.Second) @@ -215,16 +215,16 @@ func TestTaskManagement(t *testing.T) { defer engine.Close() indexName := "test_task_index" - err = engine.CreateIndex(ctx, rag.IndexConfig{Name: indexName}) + err = engine.CreateIndex(ctx, driver.IndexConfig{Name: indexName}) assert.NoError(t, err) defer engine.DeleteIndex(ctx, indexName) // Create a batch operation to get a task ID - docs := make([]*rag.Document, 5) + docs := make([]*driver.Document, 5) for i := 0; i < 5; i++ { // Use proper UUID format docID := fmt.Sprintf("00000000-0000-0000-0000-%012d", i) - docs[i] = &rag.Document{ + docs[i] = &driver.Document{ DocID: docID, Content: fmt.Sprintf("Test document for task management %d", i), } @@ -237,7 +237,7 @@ func TestTaskManagement(t *testing.T) { assert.NotEmpty(t, taskID) // Test GetTaskInfo with retries - var taskInfo *rag.TaskInfo + var taskInfo *driver.TaskInfo var getTaskErr error for i := 0; i < 5; i++ { taskInfo, getTaskErr = engine.GetTaskInfo(ctx, taskID) @@ -305,7 +305,7 @@ func TestResourceLeaks(t *testing.T) { for i := 0; i < 2; i++ { // Further reduce iterations func() { indexName := fmt.Sprintf("test_leak_index_%d", i) - err = engine.CreateIndex(ctx, rag.IndexConfig{Name: indexName}) + err = engine.CreateIndex(ctx, driver.IndexConfig{Name: indexName}) assert.NoError(t, err) // Ensure index is deleted even if test fails @@ -318,7 +318,7 @@ func TestResourceLeaks(t *testing.T) { // Use proper UUID format docID := fmt.Sprintf("00000000-0000-0000-0000-%012d", i) - doc := &rag.Document{ + doc := &driver.Document{ DocID: docID, Content: "Test document for leak detection", } @@ -329,7 +329,7 @@ func TestResourceLeaks(t *testing.T) { time.Sleep(time.Second) // Search operations - searchOpts := rag.VectorSearchOptions{ + searchOpts := driver.VectorSearchOptions{ QueryText: "test document", TopK: 5, } @@ -384,7 +384,7 @@ func TestConcurrentOperations(t *testing.T) { // Ensure index name is unique indexName := fmt.Sprintf("test_concurrent_index_%d", time.Now().UnixNano()) - err = engine.CreateIndex(ctx, rag.IndexConfig{Name: indexName}) + err = engine.CreateIndex(ctx, driver.IndexConfig{Name: indexName}) assert.NoError(t, err) // Use sync.Once to ensure cleanup happens only once @@ -425,7 +425,7 @@ func TestConcurrentOperations(t *testing.T) { defer wg.Done() // Use proper UUID format docID := fmt.Sprintf("00000000-0000-0000-0000-%012d", i) - doc := &rag.Document{ + doc := &driver.Document{ DocID: docID, Content: fmt.Sprintf("Concurrent test document %d", i), Metadata: map[string]interface{}{"index": i}, @@ -448,7 +448,7 @@ func TestConcurrentOperations(t *testing.T) { for i := 0; i < numOps; i++ { go func(i int) { defer wg.Done() - searchOpts := rag.VectorSearchOptions{ + searchOpts := driver.VectorSearchOptions{ QueryText: fmt.Sprintf("test document %d", i), TopK: 5, } @@ -497,7 +497,7 @@ func TestQdrantEngineErrors(t *testing.T) { // Ensure index name is unique indexName := fmt.Sprintf("test_error_index_%d", time.Now().UnixNano()) - err = engine.CreateIndex(ctx, rag.IndexConfig{Name: indexName}) + err = engine.CreateIndex(ctx, driver.IndexConfig{Name: indexName}) assert.NoError(t, err) // Use defer to ensure proper cleanup order @@ -530,7 +530,7 @@ func TestQdrantEngineErrors(t *testing.T) { assert.Equal(t, "invalid task ID format", err.Error()) // Test batch operations with empty input - _, err = engine.IndexBatch(ctx, indexName, []*rag.Document{}) + _, err = engine.IndexBatch(ctx, indexName, []*driver.Document{}) assert.Error(t, err) assert.Equal(t, "empty document batch", err.Error()) @@ -539,7 +539,7 @@ func TestQdrantEngineErrors(t *testing.T) { assert.Equal(t, "empty batch", err.Error()) // Test operations with nil context - searchOpts := rag.VectorSearchOptions{ + searchOpts := driver.VectorSearchOptions{ QueryText: "test", TopK: 5, } @@ -555,7 +555,7 @@ func TestQdrantEngineErrors(t *testing.T) { assert.Contains(t, err.Error(), "context canceled") // Test invalid vector dimension - invalidDoc := &rag.Document{ + invalidDoc := &driver.Document{ DocID: "00000000-0000-0000-0000-000000000001", Content: "Test document", Embeddings: []float32{0.1, 0.2}, // Invalid dimension diff --git a/rag/qdrant/file.go b/rag/driver/qdrant/file.go similarity index 60% rename from rag/qdrant/file.go rename to rag/driver/qdrant/file.go index b4928b6..8447960 100644 --- a/rag/qdrant/file.go +++ b/rag/driver/qdrant/file.go @@ -8,28 +8,33 @@ import ( "os" "path/filepath" - "github.com/yaoapp/gou/rag" + "github.com/google/uuid" + "github.com/yaoapp/gou/rag/driver" ) -// FileUploader implements the FileUpload interface for Qdrant -type FileUploader struct { - engine *Engine +// FileUpload implements the driver.FileUpload interface for Qdrant +type FileUpload struct { + engine *Engine + vectorizer driver.Vectorizer } -// NewFileUploader creates a new FileUploader instance -func NewFileUploader(engine *Engine) *FileUploader { - return &FileUploader{engine: engine} +// NewFileUpload creates a new FileUpload instance +func NewFileUpload(engine *Engine, vectorizer driver.Vectorizer) (*FileUpload, error) { + return &FileUpload{ + engine: engine, + vectorizer: vectorizer, + }, nil } // Upload processes content from a reader -func (f *FileUploader) Upload(ctx context.Context, reader io.Reader, opts rag.FileUploadOptions) (*rag.FileUploadResult, error) { +func (f *FileUpload) Upload(ctx context.Context, reader io.Reader, opts driver.FileUploadOptions) (*driver.FileUploadResult, error) { if opts.ChunkSize <= 0 { opts.ChunkSize = 1000 // default chunk size } scanner := bufio.NewScanner(reader) var buffer string - var documents []*rag.Document + var documents []*driver.Document docID := 1 // Read content in chunks @@ -38,11 +43,21 @@ func (f *FileUploader) Upload(ctx context.Context, reader io.Reader, opts rag.Fi buffer += line + "\n" if len(buffer) >= opts.ChunkSize { - doc := &rag.Document{ - DocID: fmt.Sprintf("00000000-0000-0000-0000-%012d", docID), + // Generate UUID for the document + docUUID := uuid.New().String() + + // Vectorize the chunk + embeddings, err := f.vectorizer.Vectorize(ctx, buffer) + if err != nil { + return nil, fmt.Errorf("failed to vectorize content chunk %d: %w", docID, err) + } + + doc := &driver.Document{ + DocID: docUUID, Content: buffer, ChunkSize: opts.ChunkSize, ChunkOverlap: opts.ChunkOverlap, + Embeddings: embeddings, Metadata: map[string]interface{}{ "chunk_number": docID, }, @@ -55,11 +70,18 @@ func (f *FileUploader) Upload(ctx context.Context, reader io.Reader, opts rag.Fi // Handle any remaining content if len(buffer) > 0 { - doc := &rag.Document{ - DocID: fmt.Sprintf("00000000-0000-0000-0000-%012d", docID), + docUUID := uuid.New().String() + embeddings, err := f.vectorizer.Vectorize(ctx, buffer) + if err != nil { + return nil, fmt.Errorf("failed to vectorize final content chunk: %w", err) + } + + doc := &driver.Document{ + DocID: docUUID, Content: buffer, ChunkSize: opts.ChunkSize, ChunkOverlap: opts.ChunkOverlap, + Embeddings: embeddings, Metadata: map[string]interface{}{ "chunk_number": docID, }, @@ -71,7 +93,7 @@ func (f *FileUploader) Upload(ctx context.Context, reader io.Reader, opts rag.Fi return nil, fmt.Errorf("error reading content: %w", err) } - result := &rag.FileUploadResult{ + result := &driver.FileUploadResult{ Documents: documents, } @@ -95,7 +117,7 @@ func (f *FileUploader) Upload(ctx context.Context, reader io.Reader, opts rag.Fi } // UploadFile processes content from a file path -func (f *FileUploader) UploadFile(ctx context.Context, path string, opts rag.FileUploadOptions) (*rag.FileUploadResult, error) { +func (f *FileUpload) UploadFile(ctx context.Context, path string, opts driver.FileUploadOptions) (*driver.FileUploadResult, error) { file, err := os.Open(path) if err != nil { return nil, fmt.Errorf("error opening file: %w", err) diff --git a/rag/qdrant/file_test.go b/rag/driver/qdrant/file_test.go similarity index 84% rename from rag/qdrant/file_test.go rename to rag/driver/qdrant/file_test.go index e62db19..dfc844c 100644 --- a/rag/qdrant/file_test.go +++ b/rag/driver/qdrant/file_test.go @@ -9,7 +9,7 @@ import ( "time" "github.com/stretchr/testify/assert" - "github.com/yaoapp/gou/rag" + "github.com/yaoapp/gou/rag/driver" ) func TestFileUploader_Upload(t *testing.T) { @@ -21,25 +21,26 @@ func TestFileUploader_Upload(t *testing.T) { indexName := fmt.Sprintf("test_upload_%d", time.Now().UnixNano()) // Create a test index - err = engine.CreateIndex(context.Background(), rag.IndexConfig{ + err = engine.CreateIndex(context.Background(), driver.IndexConfig{ Name: indexName, Driver: "qdrant", }) assert.NoError(t, err) defer engine.DeleteIndex(context.Background(), indexName) - uploader := NewFileUploader(engine) + uploader, err := NewFileUpload(engine, engine.vectorizer) + assert.NoError(t, err) tests := []struct { name string content string - opts rag.FileUploadOptions + opts driver.FileUploadOptions wantErr bool }{ { name: "Basic upload", content: "This is a test document", - opts: rag.FileUploadOptions{ + opts: driver.FileUploadOptions{ IndexName: indexName, ChunkSize: 100, }, @@ -48,7 +49,7 @@ func TestFileUploader_Upload(t *testing.T) { { name: "Async upload", content: "This is an async test document", - opts: rag.FileUploadOptions{ + opts: driver.FileUploadOptions{ IndexName: indexName, ChunkSize: 100, Async: true, @@ -58,7 +59,7 @@ func TestFileUploader_Upload(t *testing.T) { { name: "Upload with chunks", content: strings.Repeat("Test content ", 100), - opts: rag.FileUploadOptions{ + opts: driver.FileUploadOptions{ IndexName: indexName, ChunkSize: 100, ChunkOverlap: 20, @@ -101,14 +102,15 @@ func TestFileUploader_UploadFile(t *testing.T) { indexName := fmt.Sprintf("test_file_upload_%d", time.Now().UnixNano()) // Create a test index - err = engine.CreateIndex(context.Background(), rag.IndexConfig{ + err = engine.CreateIndex(context.Background(), driver.IndexConfig{ Name: indexName, Driver: "qdrant", }) assert.NoError(t, err) defer engine.DeleteIndex(context.Background(), indexName) - uploader := NewFileUploader(engine) + uploader, err := NewFileUpload(engine, engine.vectorizer) + assert.NoError(t, err) // Create a temporary test file content := "This is a test file content\nWith multiple lines\nFor testing purposes" @@ -123,12 +125,12 @@ func TestFileUploader_UploadFile(t *testing.T) { tests := []struct { name string - opts rag.FileUploadOptions + opts driver.FileUploadOptions wantErr bool }{ { name: "Basic file upload", - opts: rag.FileUploadOptions{ + opts: driver.FileUploadOptions{ IndexName: indexName, ChunkSize: 100, }, @@ -136,7 +138,7 @@ func TestFileUploader_UploadFile(t *testing.T) { }, { name: "Async file upload", - opts: rag.FileUploadOptions{ + opts: driver.FileUploadOptions{ IndexName: indexName, ChunkSize: 100, Async: true, diff --git a/rag/types.go b/rag/driver/types.go similarity index 99% rename from rag/types.go rename to rag/driver/types.go index 140e50a..feb1b23 100644 --- a/rag/types.go +++ b/rag/driver/types.go @@ -1,4 +1,4 @@ -package rag +package driver import ( "context" diff --git a/rag/rag.go b/rag/rag.go new file mode 100644 index 0000000..1a4ae78 --- /dev/null +++ b/rag/rag.go @@ -0,0 +1,63 @@ +package rag + +import ( + "fmt" + "strconv" + + "github.com/yaoapp/gou/rag/driver" + "github.com/yaoapp/gou/rag/driver/openai" + "github.com/yaoapp/gou/rag/driver/qdrant" +) + +const ( + // DriverQdrant is the Qdrant vector store driver + DriverQdrant = "qdrant" + // DriverOpenAI is the OpenAI embeddings driver + DriverOpenAI = "openai" +) + +// NewEngine creates a new RAG engine instance +func NewEngine(driverName string, config driver.IndexConfig) (driver.Engine, error) { + switch driverName { + case DriverQdrant: + // Convert IndexConfig to qdrant.Config + qConfig := qdrant.Config{ + Host: config.Options["host"], + APIKey: config.Options["api_key"], + } + if portStr, ok := config.Options["port"]; ok { + if port, err := strconv.ParseUint(portStr, 10, 32); err == nil { + qConfig.Port = uint32(port) + } + } + return qdrant.NewEngine(qConfig) + default: + return nil, fmt.Errorf("unsupported engine driver: %s", driverName) + } +} + +// NewVectorizer creates a new vectorizer instance +func NewVectorizer(driverName string, config driver.VectorizeConfig) (driver.Vectorizer, error) { + switch driverName { + case DriverOpenAI: + return openai.New(openai.Config{ + APIKey: config.Options["api_key"], + Model: config.Model, + }) + default: + return nil, fmt.Errorf("unsupported vectorizer driver: %s", driverName) + } +} + +// NewFileUpload creates a new file upload instance +func NewFileUpload(driverName string, engine driver.Engine, vectorizer driver.Vectorizer) (driver.FileUpload, error) { + switch driverName { + case DriverQdrant: + if qEngine, ok := engine.(*qdrant.Engine); ok { + return qdrant.NewFileUpload(qEngine, vectorizer) + } + return nil, fmt.Errorf("engine type mismatch: expected *qdrant.Engine") + default: + return nil, fmt.Errorf("unsupported file upload driver: %s", driverName) + } +}