Skip to content

Commit

Permalink
Add GetMetadata method to RAG engine and corresponding tests
Browse files Browse the repository at this point in the history
This update introduces the GetMetadata method to the RAG engine, allowing retrieval of document metadata by its ID. The implementation includes error handling for non-existent documents and collections, as well as checks for nil context and closed engine states. Integration tests have been added to validate the GetMetadata functionality, ensuring it behaves correctly under various scenarios. This enhancement improves the metadata management capabilities of the RAG engine.
  • Loading branch information
trheyi committed Jan 2, 2025
1 parent 18e7497 commit aeeeb6a
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 4 deletions.
81 changes: 77 additions & 4 deletions rag/driver/qdrant/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,34 @@ func (e *Engine) IndexDoc(ctx context.Context, indexName string, doc *driver.Doc
}

if doc.Metadata != nil {
metadataStruct, err := qdrant.NewStruct(doc.Metadata)
if err != nil {
return fmt.Errorf("failed to convert metadata: %w", err)
payload := make(map[string]*qdrant.Value)
for k, v := range doc.Metadata {
switch val := v.(type) {
case string:
payload[k] = qdrant.NewValueString(val)
case float64:
payload[k] = qdrant.NewValueDouble(val)
case bool:
payload[k] = qdrant.NewValueBool(val)
case []string:
values := make([]*qdrant.Value, len(val))
for i, s := range val {
values[i] = qdrant.NewValueString(s)
}
payload[k] = &qdrant.Value{
Kind: &qdrant.Value_ListValue{
ListValue: &qdrant.ListValue{
Values: values,
},
},
}
case map[string]interface{}:
if nested, err := qdrant.NewStruct(val); err == nil {
payload[k] = qdrant.NewValueStruct(nested)
}
}
}
point.Payload["metadata"] = qdrant.NewValueStruct(metadataStruct)
point.Payload["metadata"] = qdrant.NewValueStruct(&qdrant.Struct{Fields: payload})
}

_, err = e.client.Upsert(ctx, &qdrant.UpsertPoints{
Expand Down Expand Up @@ -354,6 +377,23 @@ func convertStructToMap(s *qdrant.Struct) map[string]interface{} {
result[k] = x.DoubleValue
case *qdrant.Value_BoolValue:
result[k] = x.BoolValue
case *qdrant.Value_ListValue:
if x.ListValue != nil {
list := make([]interface{}, len(x.ListValue.Values))
for i, lv := range x.ListValue.Values {
switch lx := lv.Kind.(type) {
case *qdrant.Value_StringValue:
list[i] = lx.StringValue
case *qdrant.Value_DoubleValue:
list[i] = lx.DoubleValue
case *qdrant.Value_BoolValue:
list[i] = lx.BoolValue
case *qdrant.Value_StructValue:
list[i] = convertStructToMap(lx.StructValue)
}
}
result[k] = list
}
case *qdrant.Value_StructValue:
result[k] = convertStructToMap(x.StructValue)
}
Expand Down Expand Up @@ -591,3 +631,36 @@ func (e *Engine) HasIndex(ctx context.Context, name string) (bool, error) {
}
return false, nil
}

// GetMetadata looks up the point stored for DocID in the indexName collection
// and returns only the contents of its "metadata" payload entry as a plain map.
//
// The context is validated through e.checkContext first and any error it
// reports is returned unchanged. A lookup error whose message contains
// "doesn't exist" is reported as a missing collection; any other lookup error
// is wrapped as a generic retrieval failure. An empty lookup result yields a
// "document not found" error. When the point has no "metadata" entry, or that
// entry is not a struct value, an empty non-nil map is returned.
func (e *Engine) GetMetadata(ctx context.Context, indexName string, DocID string) (map[string]interface{}, error) {
	if err := e.checkContext(ctx); err != nil {
		return nil, err
	}

	// Ask for the payload only; vectors are excluded so they are not transferred.
	request := &qdrant.GetPoints{
		CollectionName: indexName,
		Ids:            []*qdrant.PointId{qdrant.NewIDNum(stringToUint64ID(DocID))},
		WithPayload:    qdrant.NewWithPayload(true),
		WithVectors:    qdrant.NewWithVectors(false),
	}

	found, err := e.client.Get(ctx, request)
	switch {
	case err == nil:
		// Lookup succeeded; continue below.
	case strings.Contains(err.Error(), "doesn't exist"):
		return nil, fmt.Errorf("collection doesn't exist: %w", err)
	default:
		return nil, fmt.Errorf("failed to get document metadata: %w", err)
	}

	if len(found) == 0 {
		return nil, fmt.Errorf("document not found")
	}

	raw := found[0].Payload["metadata"]
	if raw == nil {
		return make(map[string]interface{}), nil
	}

	fields := raw.GetStructValue()
	if fields == nil {
		return make(map[string]interface{}), nil
	}

	return convertStructToMap(fields), nil
}
64 changes: 64 additions & 0 deletions rag/driver/qdrant/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,67 @@ func TestQdrantEngineErrors(t *testing.T) {
assert.Error(t, err)
assert.Equal(t, "empty batch", err.Error())
}

// TestGetMetadata is an integration test for Engine.GetMetadata. It indexes a
// document carrying string, float, string-list and nested-map metadata into a
// freshly created, uniquely named index and then checks:
//   - the metadata round-trips, including the list and nested map values;
//   - a missing document yields a "document not found" error;
//   - a missing collection yields a "collection doesn't exist" error;
//   - a nil context and a closed engine are both rejected.
//
// Setup steps abort the test on failure, because every later check depends on
// them and continuing would only produce nil dereferences or misleading
// follow-up failures.
func TestGetMetadata(t *testing.T) {
	ctx := context.Background()
	config := getTestConfig(t)

	engine, err := NewEngine(config)
	if err != nil {
		t.Fatalf("NewEngine failed: %v", err)
	}
	defer engine.Close()

	// A timestamped name keeps this run's index separate from other runs.
	indexName := fmt.Sprintf("test_metadata_index_%d", time.Now().UnixNano())
	if err = engine.CreateIndex(ctx, driver.IndexConfig{Name: indexName}); err != nil {
		t.Fatalf("CreateIndex(%q) failed: %v", indexName, err)
	}
	defer engine.DeleteIndex(ctx, indexName)

	// Test document with metadata
	doc := &driver.Document{
		DocID:   "test-doc-metadata",
		Content: "Test document with metadata",
		Metadata: map[string]interface{}{
			"type":    "test",
			"version": 1.0,
			"tags":    []string{"test", "metadata"},
			"nested": map[string]interface{}{
				"key": "value",
			},
		},
	}

	// Index the document
	if err = engine.IndexDoc(ctx, indexName, doc); err != nil {
		t.Fatalf("IndexDoc failed: %v", err)
	}

	// Test GetMetadata
	metadata, err := engine.GetMetadata(ctx, indexName, doc.DocID)
	assert.NoError(t, err)
	assert.NotNil(t, metadata)
	assert.Equal(t, "test", metadata["type"])
	assert.Equal(t, 1.0, metadata["version"])
	// List values are read back as []interface{} and nested structs as
	// map[string]interface{}.
	assert.Equal(t, []interface{}{"test", "metadata"}, metadata["tags"])
	assert.Equal(t, map[string]interface{}{"key": "value"}, metadata["nested"])

	// Test GetMetadata with non-existent document
	_, err = engine.GetMetadata(ctx, indexName, "non-existent-doc")
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "document not found")

	// Test GetMetadata with non-existent collection
	_, err = engine.GetMetadata(ctx, "non-existent-index", doc.DocID)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "collection doesn't exist")

	// Test GetMetadata with nil context. The nil is deliberate: this verifies
	// the engine's own guard rather than recommending nil contexts.
	_, err = engine.GetMetadata(nil, indexName, doc.DocID)
	assert.Error(t, err)
	assert.Equal(t, "nil context", err.Error())

	// Test GetMetadata after engine is closed
	closedEngine, err := NewEngine(config)
	if err != nil {
		t.Fatalf("NewEngine (for closed-engine check) failed: %v", err)
	}
	err = closedEngine.Close()
	assert.NoError(t, err)
	_, err = closedEngine.GetMetadata(ctx, indexName, doc.DocID)
	assert.Error(t, err)
	assert.Equal(t, "engine is closed", err.Error())
}
1 change: 1 addition & 0 deletions rag/driver/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ type Engine interface {
DeleteDoc(ctx context.Context, indexName string, DocID string) error
DeleteBatch(ctx context.Context, indexName string, DocIDs []string) (string, error) // Returns TaskID
HasDocument(ctx context.Context, indexName string, DocID string) (bool, error)
GetMetadata(ctx context.Context, indexName string, DocID string) (map[string]interface{}, error) // Get document metadata only

// Task operations
GetTaskInfo(ctx context.Context, taskID string) (*TaskInfo, error)
Expand Down
6 changes: 6 additions & 0 deletions rag/rag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,12 @@ func TestRAGIntegration(t *testing.T) {
assert.Equal(t, doc.Content, retrieved.Content)
assert.Equal(t, doc.DocID, retrieved.DocID)

// Test GetMetadata
metadata, err := engine.GetMetadata(ctx, indexConfig.Name, doc.DocID)
assert.NoError(t, err)
assert.NotNil(t, metadata)
assert.Equal(t, "test", metadata["type"])

// Search
results, err := engine.Search(ctx, indexConfig.Name, nil, driver.VectorSearchOptions{
QueryText: "test document",
Expand Down

0 comments on commit aeeeb6a

Please sign in to comment.