diff --git a/rag/driver/qdrant/engine.go b/rag/driver/qdrant/engine.go index 6d246a0..e469cae 100644 --- a/rag/driver/qdrant/engine.go +++ b/rag/driver/qdrant/engine.go @@ -161,11 +161,34 @@ func (e *Engine) IndexDoc(ctx context.Context, indexName string, doc *driver.Doc } if doc.Metadata != nil { - metadataStruct, err := qdrant.NewStruct(doc.Metadata) - if err != nil { - return fmt.Errorf("failed to convert metadata: %w", err) + payload := make(map[string]*qdrant.Value) + for k, v := range doc.Metadata { + switch val := v.(type) { + case string: + payload[k] = qdrant.NewValueString(val) + case float64: + payload[k] = qdrant.NewValueDouble(val) + case bool: + payload[k] = qdrant.NewValueBool(val) + case []string: + values := make([]*qdrant.Value, len(val)) + for i, s := range val { + values[i] = qdrant.NewValueString(s) + } + payload[k] = &qdrant.Value{ + Kind: &qdrant.Value_ListValue{ + ListValue: &qdrant.ListValue{ + Values: values, + }, + }, + } + case map[string]interface{}: + if nested, err := qdrant.NewStruct(val); err == nil { + payload[k] = qdrant.NewValueStruct(nested) + } + } } - point.Payload["metadata"] = qdrant.NewValueStruct(metadataStruct) + point.Payload["metadata"] = qdrant.NewValueStruct(&qdrant.Struct{Fields: payload}) } _, err = e.client.Upsert(ctx, &qdrant.UpsertPoints{ @@ -354,6 +377,23 @@ func convertStructToMap(s *qdrant.Struct) map[string]interface{} { result[k] = x.DoubleValue case *qdrant.Value_BoolValue: result[k] = x.BoolValue + case *qdrant.Value_ListValue: + if x.ListValue != nil { + list := make([]interface{}, len(x.ListValue.Values)) + for i, lv := range x.ListValue.Values { + switch lx := lv.Kind.(type) { + case *qdrant.Value_StringValue: + list[i] = lx.StringValue + case *qdrant.Value_DoubleValue: + list[i] = lx.DoubleValue + case *qdrant.Value_BoolValue: + list[i] = lx.BoolValue + case *qdrant.Value_StructValue: + list[i] = convertStructToMap(lx.StructValue) + } + } + result[k] = list + } case *qdrant.Value_StructValue: result[k] = convertStructToMap(x.StructValue) } @@ -591,3 +631,36 @@ func (e *Engine) HasIndex(ctx context.Context, name string) (bool, error) { } return false, nil } + +// GetMetadata retrieves only the metadata of a document by its ID from the specified collection +func (e *Engine) GetMetadata(ctx context.Context, indexName string, DocID string) (map[string]interface{}, error) { + if err := e.checkContext(ctx); err != nil { + return nil, err + } + + points, err := e.client.Get(ctx, &qdrant.GetPoints{ + CollectionName: indexName, + Ids: []*qdrant.PointId{qdrant.NewIDNum(stringToUint64ID(DocID))}, + WithPayload: qdrant.NewWithPayload(true), + WithVectors: qdrant.NewWithVectors(false), // Don't fetch vectors to save memory + }) + if err != nil { + if strings.Contains(err.Error(), "doesn't exist") { + return nil, fmt.Errorf("collection doesn't exist: %w", err) + } + return nil, fmt.Errorf("failed to get document metadata: %w", err) + } + + if len(points) == 0 { + return nil, fmt.Errorf("document not found") + } + + point := points[0] + if metadataValue := point.Payload["metadata"]; metadataValue != nil { + if metadataStruct := metadataValue.GetStructValue(); metadataStruct != nil { + return convertStructToMap(metadataStruct), nil + } + } + + return make(map[string]interface{}), nil +} diff --git a/rag/driver/qdrant/engine_test.go b/rag/driver/qdrant/engine_test.go index 96558f8..079a97b 100644 --- a/rag/driver/qdrant/engine_test.go +++ b/rag/driver/qdrant/engine_test.go @@ -596,3 +596,67 @@ func TestQdrantEngineErrors(t *testing.T) { assert.Error(t, err) assert.Equal(t, "empty batch", err.Error()) } + +// TestGetMetadata tests the GetMetadata functionality +func TestGetMetadata(t *testing.T) { + ctx := context.Background() + config := getTestConfig(t) + + engine, err := NewEngine(config) + assert.NoError(t, err) + defer engine.Close() + + indexName := fmt.Sprintf("test_metadata_index_%d", time.Now().UnixNano()) + err = engine.CreateIndex(ctx, driver.IndexConfig{Name: indexName}) + assert.NoError(t, err) + defer engine.DeleteIndex(ctx, indexName) + + // Test document with metadata + doc := &driver.Document{ + DocID: "test-doc-metadata", + Content: "Test document with metadata", + Metadata: map[string]interface{}{ + "type": "test", + "version": 1.0, + "tags": []string{"test", "metadata"}, + "nested": map[string]interface{}{ + "key": "value", + }, + }, + } + + // Index the document + err = engine.IndexDoc(ctx, indexName, doc) + assert.NoError(t, err) + + // Test GetMetadata + metadata, err := engine.GetMetadata(ctx, indexName, doc.DocID) + assert.NoError(t, err) + assert.NotNil(t, metadata) + assert.Equal(t, "test", metadata["type"]) + assert.Equal(t, 1.0, metadata["version"]) + + // Test GetMetadata with non-existent document + _, err = engine.GetMetadata(ctx, indexName, "non-existent-doc") + assert.Error(t, err) + assert.Contains(t, err.Error(), "document not found") + + // Test GetMetadata with non-existent collection + _, err = engine.GetMetadata(ctx, "non-existent-index", doc.DocID) + assert.Error(t, err) + assert.Contains(t, err.Error(), "collection doesn't exist") + + // Test GetMetadata with nil context + _, err = engine.GetMetadata(nil, indexName, doc.DocID) + assert.Error(t, err) + assert.Equal(t, "nil context", err.Error()) + + // Test GetMetadata after engine is closed + closedEngine, err := NewEngine(config) + assert.NoError(t, err) + err = closedEngine.Close() + assert.NoError(t, err) + _, err = closedEngine.GetMetadata(ctx, indexName, doc.DocID) + assert.Error(t, err) + assert.Equal(t, "engine is closed", err.Error()) +} diff --git a/rag/driver/types.go b/rag/driver/types.go index 57be794..330d175 100644 --- a/rag/driver/types.go +++ b/rag/driver/types.go @@ -94,6 +94,7 @@ type Engine interface { DeleteDoc(ctx context.Context, indexName string, DocID string) error DeleteBatch(ctx context.Context, indexName string, DocIDs []string) (string, error) // Returns TaskID HasDocument(ctx context.Context, indexName string, DocID string) (bool, error) + GetMetadata(ctx context.Context, indexName string, DocID string) (map[string]interface{}, error) // Get document metadata only // Task operations GetTaskInfo(ctx context.Context, taskID string) (*TaskInfo, error) diff --git a/rag/rag_test.go b/rag/rag_test.go index 2888538..32f0f4e 100644 --- a/rag/rag_test.go +++ b/rag/rag_test.go @@ -300,6 +300,12 @@ func TestRAGIntegration(t *testing.T) { assert.Equal(t, doc.Content, retrieved.Content) assert.Equal(t, doc.DocID, retrieved.DocID) + // Test GetMetadata + metadata, err := engine.GetMetadata(ctx, indexConfig.Name, doc.DocID) + assert.NoError(t, err) + assert.NotNil(t, metadata) + assert.Equal(t, "test", metadata["type"]) + // Search results, err := engine.Search(ctx, indexConfig.Name, nil, driver.VectorSearchOptions{ QueryText: "test document",