perf(catalog): enhance the stability of embeddings saving
Yougigun committed Oct 31, 2024
1 parent 6c97540 · commit 2b2ec52
Showing 2 changed files with 118 additions and 28 deletions.
35 changes: 30 additions & 5 deletions pkg/service/pipeline.go
@@ -20,13 +20,13 @@ const chunkLength = 1024
const chunkOverlap = 200
const NamespaceID = "preset"


// Note: this pipeline is for the old indexing pipeline
const ConvertDocToMDPipelineID = "indexing-convert-pdf"
const DocToMDVersion = "v1.1.1"

// TODO: the pipeline id is not correct, need to update the pipeline id
const ConvertDocToMDPipelineID2 = "indexing-advanced-convert-doc"

// TODO: the version is not correct, need to update the version
const DocToMDVersion2 = "v1.0.1"

@@ -300,13 +300,29 @@ func (s *Service) SplitTextPipe(ctx context.Context, caller uuid.UUID, requester
return filteredResult, nil
}

// EmbeddingTextPipe uses the embedding pipeline to convert text into vectors and consume caller's credits.
// It processes the input texts in batches, triggers the embedding pipeline for each batch, and collects the results.
// The function returns a 2D slice of float32 representing the vectors for the input texts.
// EmbeddingTextPipe converts multiple text inputs into vector embeddings using a pipeline service.
// It processes texts in parallel batches for efficiency while managing resource usage.
//
// Parameters:
// - ctx: Context for the operation
// - caller: UUID of the calling user
// - requester: UUID of the requesting entity (optional)
// - texts: Slice of strings to be converted to embeddings
//
// Returns:
// - [][]float32: 2D slice where each inner slice is a vector embedding
// - error: Any error encountered during processing
//
// The function:
// - Processes texts in batches of 32
// - Limits concurrent processing to 5 goroutines
// - Maintains input order in the output
// - Cancels all operations if any batch fails
func (s *Service) EmbeddingTextPipe(ctx context.Context, caller uuid.UUID, requester uuid.UUID, texts []string) ([][]float32, error) {
ctx, ctxCancel := context.WithCancel(ctx)
defer ctxCancel()
const maxBatchSize = 32
const maxConcurrentGoroutines = 5
var md metadata.MD
if requester != uuid.Nil {
md = metadata.New(map[string]string{
@@ -338,19 +354,28 @@ func (s *Service) EmbeddingTextPipe(ctx context.Context, caller uuid.UUID, reque
// - Extract the vector from the response.
// - Send the result to the results channel.
// If an error occurs, send the error to the error channel.
// Create a semaphore channel to limit concurrent goroutines to maxConcurrentGoroutines
sem := make(chan struct{}, maxConcurrentGoroutines)
for i := 0; i < len(texts); i += maxBatchSize {

end := i + maxBatchSize
if end > len(texts) {
end = len(texts)
}
batch := texts[i:end]
batchIndex := i / maxBatchSize


// Acquire semaphore before starting goroutine
sem <- struct{}{}
wg.Add(1)
go utils.GoRecover(func() {
// Release semaphore when goroutine completes
defer func() { <-sem }()
defer wg.Done()

func(batch []string, index int) {
ctx_ := metadata.NewOutgoingContext(ctx, md)
defer wg.Done()

inputs := make([]*structpb.Struct, 0, len(batch))
for _, text := range batch {
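The batching-plus-semaphore pattern this change introduces in EmbeddingTextPipe can be illustrated on its own. The following is a minimal, self-contained sketch rather than the service's code: embedBatch is a hypothetical stand-in for the pipeline trigger, while the batch size of 32, the concurrency cap of 5, the order-preserving result collection, and the cancel-on-first-failure behaviour mirror what the doc comment above describes.

```go
package main

import (
	"context"
	"fmt"
	"sync"
)

const (
	maxBatchSize            = 32 // texts per pipeline call
	maxConcurrentGoroutines = 5  // cap on in-flight batches
)

// embedBatch is a hypothetical stand-in for the real embedding pipeline trigger.
func embedBatch(ctx context.Context, batch []string) ([][]float32, error) {
	vectors := make([][]float32, len(batch))
	for i, text := range batch {
		vectors[i] = []float32{float32(len(text))} // dummy vector
	}
	return vectors, nil
}

// embedTexts processes texts in batches, limits concurrency with a
// buffered-channel semaphore, preserves input order, and cancels the
// remaining work if any batch fails.
func embedTexts(ctx context.Context, texts []string) ([][]float32, error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	numBatches := (len(texts) + maxBatchSize - 1) / maxBatchSize
	results := make([][][]float32, numBatches) // indexed by batch to keep input order
	errCh := make(chan error, numBatches)
	sem := make(chan struct{}, maxConcurrentGoroutines)
	var wg sync.WaitGroup

	for i := 0; i < len(texts); i += maxBatchSize {
		end := i + maxBatchSize
		if end > len(texts) {
			end = len(texts)
		}
		batch := texts[i:end]
		batchIndex := i / maxBatchSize

		sem <- struct{}{} // acquire before launching the goroutine
		wg.Add(1)
		go func(batch []string, index int) {
			defer func() { <-sem }() // release when the goroutine completes
			defer wg.Done()

			select {
			case <-ctx.Done(): // another batch already failed
				return
			default:
			}

			vectors, err := embedBatch(ctx, batch)
			if err != nil {
				errCh <- fmt.Errorf("batch %d: %w", index, err)
				cancel() // stop the remaining batches
				return
			}
			results[index] = vectors
		}(batch, batchIndex)
	}

	wg.Wait()
	close(errCh)
	if err := <-errCh; err != nil {
		return nil, err
	}

	// Flatten per-batch results back into the original input order.
	out := make([][]float32, 0, len(texts))
	for _, batchVectors := range results {
		out = append(out, batchVectors...)
	}
	return out, nil
}

func main() {
	texts := make([]string, 70)
	for i := range texts {
		texts[i] = fmt.Sprintf("chunk %d", i)
	}
	vectors, err := embedTexts(context.Background(), texts)
	if err != nil {
		panic(err)
	}
	fmt.Println("embedded", len(vectors), "texts")
}
```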
111 changes: 88 additions & 23 deletions pkg/worker/worker.go
@@ -652,9 +652,34 @@ func (wp *fileToEmbWorkerPool) processChunkingFile(ctx context.Context, file rep

}

// processEmbeddingFile processes a file with embedding status.
// It retrieves chunks from MinIO, calls the embedding pipeline, saves the embeddings into the vector database and metadata into the database,
// and updates the file status to completed in the database.
// processEmbeddingFile processes a file that is ready for embedding by:
// 1. Validating the file's process status is "EMBEDDING"
// 2. Retrieving text chunks from MinIO storage and database metadata
// - Will retry once if initial chunk retrieval fails
//
// 3. Updating file metadata with embedding pipeline version info
// - Uses TextEmbedPipelineID and TextEmbedVersion from service config
//
// 4. Calling the embedding pipeline to generate vectors from text chunks
// - Uses file creator and requester UIDs for pipeline execution
//
// 5. Saving embeddings to vector database (Milvus) and metadata to SQL database
// - Creates embeddings collection named after knowledge base UID
// - Links embeddings to source text chunks and file metadata
//
// 6. Updating file status to "COMPLETED" in database
//
// Parameters:
// - ctx: Context for the operation
// - file: KnowledgeBaseFile struct containing file metadata
//
// Returns:
// - updatedFile: Updated KnowledgeBaseFile after processing
// - nextStatus: Next file process status (COMPLETED if successful)
// - err: Error if any step fails
//
// The function handles errors at each step and returns appropriate status codes.
// If chunk retrieval fails initially, it will retry once after a 1 second delay.
func (wp *fileToEmbWorkerPool) processEmbeddingFile(ctx context.Context, file repository.KnowledgeBaseFile) (updatedFile *repository.KnowledgeBaseFile, nextStatus artifactpb.FileProcessStatus, err error) {
logger, _ := logger.GetZapLogger(ctx)
// check the file status is embedding
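The retry behaviour mentioned in the processEmbeddingFile doc comment (one retry after a one-second delay when the initial chunk retrieval fails) is elided from this diff; the sketch below shows only that retry shape under stated assumptions. fetchChunks is a hypothetical placeholder for the worker's MinIO/database retrieval, not a function in this repository.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// attempts lets the hypothetical fetchChunks fail once to exercise the retry.
var attempts int

// fetchChunks is a hypothetical stand-in for the worker's chunk retrieval.
func fetchChunks(ctx context.Context, fileUID string) ([]string, error) {
	attempts++
	if attempts == 1 {
		return nil, errors.New("transient storage error")
	}
	return []string{"chunk-1", "chunk-2"}, nil
}

// fetchChunksWithRetry mirrors the "retry once after a 1 second delay"
// behaviour described in the processEmbeddingFile doc comment.
func fetchChunksWithRetry(ctx context.Context, fileUID string) ([]string, error) {
	chunks, err := fetchChunks(ctx, fileUID)
	if err == nil {
		return chunks, nil
	}
	select {
	case <-time.After(1 * time.Second):
	case <-ctx.Done():
		return nil, ctx.Err()
	}
	return fetchChunks(ctx, fileUID) // single retry; propagate the error if it fails again
}

func main() {
	chunks, err := fetchChunksWithRetry(context.Background(), "file-uid-example")
	fmt.Println(chunks, err)
}
```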
@@ -822,34 +847,74 @@ type MilvusEmbedding struct {
}

// saveEmbeddings saves embeddings into the vector database and updates the metadata in the database.
// Processes embeddings in batches of 50 to avoid timeout issues.
const batchSize = 50

func (wp *fileToEmbWorkerPool) saveEmbeddings(ctx context.Context, kbUID string, embeddings []repository.Embedding) error {
logger, _ := logger.GetZapLogger(ctx)
externalServiceCall := func(embUIDs []string) error {
// save the embeddings into vector database
milvusEmbeddings := make([]milvus.Embedding, len(embeddings))
for i, emb := range embeddings {
milvusEmbeddings[i] = milvus.Embedding{
SourceTable: emb.SourceTable,
SourceUID: emb.SourceUID.String(),
EmbeddingUID: emb.UID.String(),
Vector: emb.Vector,
if len(embeddings) == 0 {
logger.Debug("No embeddings to save")
return nil
}

totalEmbeddings := len(embeddings)

// Process embeddings in batches
for i := 0; i < totalEmbeddings; i += batchSize {
// Add context check
if err := ctx.Err(); err != nil {
return fmt.Errorf("context cancelled while processing embeddings: %w", err)
}

end := i + batchSize
if end > totalEmbeddings {
end = totalEmbeddings
}

currentBatch := embeddings[i:end]

externalServiceCall := func(_ []string) error {
// save the embeddings into vector database
milvusEmbeddings := make([]milvus.Embedding, len(currentBatch))
for j, emb := range currentBatch {
milvusEmbeddings[j] = milvus.Embedding{
SourceTable: emb.SourceTable,
SourceUID: emb.SourceUID.String(),
EmbeddingUID: emb.UID.String(),
Vector: emb.Vector,
}
}
err := wp.svc.MilvusClient.InsertVectorsToKnowledgeBaseCollection(ctx, kbUID, milvusEmbeddings)
if err != nil {
logger.Error("Failed to save embeddings batch into vector database.",
zap.String("KbUID", kbUID),
zap.Int("batch", i/batchSize+1),
zap.Int("batchSize", len(currentBatch)))
return err
}
return nil
}
err := wp.svc.MilvusClient.InsertVectorsToKnowledgeBaseCollection(ctx, kbUID, milvusEmbeddings)

_, err := wp.svc.Repository.UpsertEmbeddings(ctx, currentBatch, externalServiceCall)
if err != nil {
logger.Error("Failed to save embeddings into vector database.", zap.String("KbUID", kbUID))
logger.Error("Failed to save embeddings batch into vector database and metadata into database.",
zap.String("KbUID", kbUID),
zap.Int("batch", i/batchSize+1),
zap.Int("batchSize", len(currentBatch)))
return err
}
return nil
}
_, err := wp.svc.Repository.UpsertEmbeddings(ctx, embeddings, externalServiceCall)
if err != nil {
logger.Error("Failed to save embeddings into vector database and metadata into database.", zap.String("KbUID", kbUID))
return err

logger.Info("Embeddings batch saved successfully",
zap.String("KbUID", kbUID),
zap.Int("batch", i/batchSize+1),
zap.Int("batchSize", len(currentBatch)),
zap.Int("progress", end),
zap.Int("total", totalEmbeddings))
}
// info how many embeddings saved in which kb
logger.Info("Embeddings saved into vector database and metadata into database.",
zap.String("KbUID", kbUID), zap.Int("Embeddings count", len(embeddings)))

logger.Info("All embeddings saved into vector database and metadata into database.",
zap.String("KbUID", kbUID),
zap.Int("total embeddings", totalEmbeddings))
return nil
}

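Reduced to its control flow, the batching added to saveEmbeddings looks roughly like the sketch below. The Embedding type and the insertVectors/upsertMetadata functions are simplified placeholders standing in for the Milvus client's InsertVectorsToKnowledgeBaseCollection and the repository's UpsertEmbeddings; only the 50-item batches, the cancellation check before each batch, the vector-store write wrapped in a callback handed to the metadata upsert, and the per-batch progress logging reflect the change itself.

```go
package main

import (
	"context"
	"fmt"
	"log"
)

const batchSize = 50 // embeddings per upsert, to avoid timeouts

// Embedding is a simplified stand-in for repository.Embedding.
type Embedding struct {
	UID    string
	Vector []float32
}

// insertVectors stands in for the Milvus client's collection insert.
func insertVectors(ctx context.Context, kbUID string, batch []Embedding) error {
	// ... write vectors to the vector database ...
	return nil
}

// upsertMetadata stands in for the repository upsert: it persists the batch's
// metadata and invokes the external call (the vector-store write) as part of
// the same unit of work.
func upsertMetadata(ctx context.Context, batch []Embedding, externalCall func() error) error {
	// ... save metadata rows ...
	return externalCall()
}

// saveEmbeddings writes embeddings in batches of batchSize, checking for
// context cancellation before each batch and logging progress as it goes.
func saveEmbeddings(ctx context.Context, kbUID string, embeddings []Embedding) error {
	if len(embeddings) == 0 {
		return nil
	}
	total := len(embeddings)

	for i := 0; i < total; i += batchSize {
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("context cancelled while processing embeddings: %w", err)
		}

		end := i + batchSize
		if end > total {
			end = total
		}
		batch := embeddings[i:end]

		externalCall := func() error {
			return insertVectors(ctx, kbUID, batch)
		}
		if err := upsertMetadata(ctx, batch, externalCall); err != nil {
			return fmt.Errorf("batch %d: %w", i/batchSize+1, err)
		}
		log.Printf("saved batch %d (%d/%d embeddings)", i/batchSize+1, end, total)
	}

	log.Printf("all %d embeddings saved for knowledge base %s", total, kbUID)
	return nil
}

func main() {
	embs := make([]Embedding, 120)
	for i := range embs {
		embs[i] = Embedding{UID: fmt.Sprintf("emb-%d", i), Vector: []float32{0.1, 0.2}}
	}
	if err := saveEmbeddings(context.Background(), "kb-uid-example", embs); err != nil {
		log.Fatal(err)
	}
}
```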
