Skip to content

Commit

Permalink
enhance: (knowledge) only preload files on datasets if asked for (#336)
Browse files Browse the repository at this point in the history
  • Loading branch information
iwilltry42 authored Jan 13, 2025
1 parent fab4dda commit 62a74c4
Show file tree
Hide file tree
Showing 20 changed files with 55 additions and 25 deletions.
2 changes: 1 addition & 1 deletion knowledge/pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type IngestPathsOpts struct {
type Client interface {
CreateDataset(ctx context.Context, datasetID string, opts *types2.DatasetCreateOpts) (*types2.Dataset, error)
DeleteDataset(ctx context.Context, datasetID string) error
GetDataset(ctx context.Context, datasetID string) (*types2.Dataset, error)
GetDataset(ctx context.Context, datasetID string, opts *types2.DatasetGetOpts) (*types2.Dataset, error)
FindFile(ctx context.Context, searchFile types2.File) (*types2.File, error)
DeleteFile(ctx context.Context, datasetID, fileID string) error
ListDatasets(ctx context.Context) ([]types2.Dataset, error)
Expand Down
2 changes: 1 addition & 1 deletion knowledge/pkg/client/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ func AskDir(ctx context.Context, c Client, path string, query string, opts *Inge
func getOrCreateDataset(ctx context.Context, c Client, datasetID string, create bool) (*types.Dataset, error) {
var ds *types.Dataset
var err error
ds, err = c.GetDataset(ctx, datasetID)
ds, err = c.GetDataset(ctx, datasetID, nil)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions knowledge/pkg/client/standalone.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ func (c *StandaloneClient) DeleteDataset(ctx context.Context, datasetID string)
return c.Datastore.DeleteDataset(ctx, datasetID)
}

func (c *StandaloneClient) GetDataset(ctx context.Context, datasetID string) (*types2.Dataset, error) {
return c.Datastore.GetDataset(ctx, datasetID)
// GetDataset looks up the dataset with the given ID by delegating to the
// client's Datastore. opts is forwarded unchanged; callers in this change
// pass nil when they have no options to set.
func (c *StandaloneClient) GetDataset(ctx context.Context, datasetID string, opts *types2.DatasetGetOpts) (*types2.Dataset, error) {
return c.Datastore.GetDataset(ctx, datasetID, opts)
}

func (c *StandaloneClient) ListDatasets(ctx context.Context) ([]types2.Dataset, error) {
Expand Down
2 changes: 1 addition & 1 deletion knowledge/pkg/cmd/edit_dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (s *ClientEditDataset) Run(cmd *cobra.Command, args []string) error {
datasetID := args[0]

// Get current dataset
dataset, err := c.GetDataset(cmd.Context(), datasetID)
dataset, err := c.GetDataset(cmd.Context(), datasetID, nil)
if err != nil {
return fmt.Errorf("failed to get dataset: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion knowledge/pkg/cmd/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (s *ClientExportDatasets) Run(cmd *cobra.Command, args []string) error {
}
} else {
for _, datasetID := range dsnames {
ds, err := c.GetDataset(cmd.Context(), datasetID)
ds, err := c.GetDataset(cmd.Context(), datasetID, nil)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion knowledge/pkg/cmd/get_dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"

"github.com/gptscript-ai/knowledge/pkg/index/types"
"github.com/spf13/cobra"
)

Expand All @@ -28,7 +29,7 @@ func (s *ClientGetDataset) Run(cmd *cobra.Command, args []string) error {

datasetID := args[0]

ds, err := c.GetDataset(cmd.Context(), datasetID)
ds, err := c.GetDataset(cmd.Context(), datasetID, &types.DatasetGetOpts{IncludeFiles: true})
if err != nil {
return fmt.Errorf("failed to get dataset: %w", err)
}
Expand Down
6 changes: 3 additions & 3 deletions knowledge/pkg/datastore/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ func (s *Datastore) DeleteDataset(ctx context.Context, datasetID string) error {
return nil
}

func (s *Datastore) GetDataset(ctx context.Context, datasetID string) (*types.Dataset, error) {
return s.Index.GetDataset(ctx, datasetID)
// GetDataset returns the dataset with the given ID by delegating to the
// datastore's Index, forwarding opts unchanged.
func (s *Datastore) GetDataset(ctx context.Context, datasetID string, opts *types.DatasetGetOpts) (*types.Dataset, error) {
return s.Index.GetDataset(ctx, datasetID, opts)
}

func (s *Datastore) ListDatasets(ctx context.Context) ([]types.Dataset, error) {
Expand All @@ -61,7 +61,7 @@ func (s *Datastore) UpdateDataset(ctx context.Context, updatedDataset types.Data
return origDS, fmt.Errorf("dataset ID is required")
}

origDS, err = s.GetDataset(ctx, updatedDataset.ID)
origDS, err = s.GetDataset(ctx, updatedDataset.ID, nil)
if err != nil {
return origDS, err
}
Expand Down
2 changes: 1 addition & 1 deletion knowledge/pkg/datastore/ingest.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (s *Datastore) Ingest(ctx context.Context, datasetID string, filename strin
statusLog := log.FromCtx(ctx).With("phase", "store")

// Get dataset
ds, err := s.GetDataset(ctx, datasetID)
ds, err := s.GetDataset(ctx, datasetID, nil)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion knowledge/pkg/datastore/retrieve.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (s *Datastore) Retrieve(ctx context.Context, datasetIDs []string, query str
}

func (s *Datastore) SimilaritySearch(ctx context.Context, query string, numDocuments int, datasetID string, where map[string]string, whereDocument []chromem.WhereDocument) ([]types2.Document, error) {
ds, err := s.GetDataset(ctx, datasetID)
ds, err := s.GetDataset(ctx, datasetID, nil)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion knowledge/pkg/datastore/retrievers/bm25.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (r *BM25Retriever) Retrieve(ctx context.Context, store store.Store, query s
for _, datasetID := range datasetIDs {
// TODO: make configurable via RetrieveOpts
// silently ignore non-existent datasets
ds, err := store.GetDataset(ctx, datasetID)
ds, err := store.GetDataset(ctx, datasetID, nil)
if err != nil {
if strings.HasPrefix(err.Error(), "dataset not found") {
slog.Info("Dataset not found", "dataset", datasetID)
Expand Down
2 changes: 1 addition & 1 deletion knowledge/pkg/datastore/retrievers/retrievers.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (r *BasicRetriever) Retrieve(ctx context.Context, store store.Store, query
for _, dataset := range datasetIDs {
// TODO: make configurable via RetrieveOpts
// silently ignore non-existent datasets
ds, err := store.GetDataset(ctx, dataset)
ds, err := store.GetDataset(ctx, dataset, nil)
if err != nil {
if strings.HasPrefix(err.Error(), "dataset not found") {
continue
Expand Down
2 changes: 1 addition & 1 deletion knowledge/pkg/datastore/retrievers/routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (r *RoutingRetriever) Retrieve(ctx context.Context, store store.Store, quer

datasets := map[string]map[string]any{}
for _, dsID := range r.AvailableDatasets {
dataset, err := store.GetDataset(ctx, dsID)
dataset, err := store.GetDataset(ctx, dsID, nil)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion knowledge/pkg/datastore/retrievers/subquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (s *SubqueryRetriever) Retrieve(ctx context.Context, store store.Store, que
for _, dataset := range datasetIDs {
// TODO: make configurable via RetrieveOpts
// silently ignore non-existent datasets
ds, err := store.GetDataset(ctx, dataset)
ds, err := store.GetDataset(ctx, dataset, nil)
if err != nil {
if strings.HasPrefix(err.Error(), "dataset not found") {
continue
Expand Down
2 changes: 1 addition & 1 deletion knowledge/pkg/datastore/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

// Store declares the dataset lookup and document search operations that the
// retrievers' Retrieve methods receive as their store argument.
type Store interface {
ListDatasets(ctx context.Context) ([]types.Dataset, error)
GetDataset(ctx context.Context, datasetID string) (*types.Dataset, error)
GetDataset(ctx context.Context, datasetID string, opts *types.DatasetGetOpts) (*types.Dataset, error)
SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error)
GetDocuments(ctx context.Context, datasetID string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error)
}
15 changes: 14 additions & 1 deletion knowledge/pkg/index/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"log"
"log/slog"
"os"
"strings"
"time"

Expand All @@ -15,14 +16,26 @@ import (
)

func New(ctx context.Context, dsn string, autoMigrate bool) (Index, error) {
gormLogLevel := logger.Silent
switch os.Getenv("GORM_LOG_LEVEL") {
case "silent":
gormLogLevel = logger.Silent
case "error":
gormLogLevel = logger.Error
case "warn":
gormLogLevel = logger.Warn
case "info":
gormLogLevel = logger.Info
}

var (
indexDB Index
err error
gormCfg = &gorm.Config{
Logger: logger.New(log.Default(), logger.Config{
SlowThreshold: 200 * time.Millisecond,
Colorful: true,
LogLevel: logger.Silent,
LogLevel: gormLogLevel,
}),
TranslateError: true,
}
Expand Down
2 changes: 1 addition & 1 deletion knowledge/pkg/index/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type Index interface {

// Fundamental Dataset Operations
CreateDataset(ctx context.Context, dataset types.Dataset, opts *types.DatasetCreateOpts) error
GetDataset(ctx context.Context, datasetID string) (*types.Dataset, error)
GetDataset(ctx context.Context, datasetID string, opts *types.DatasetGetOpts) (*types.Dataset, error)
ListDatasets(ctx context.Context) ([]types.Dataset, error)
DeleteDataset(ctx context.Context, datasetID string) error

Expand Down
4 changes: 2 additions & 2 deletions knowledge/pkg/index/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ func (i *Index) CreateDataset(ctx context.Context, dataset types.Dataset, opts *
return i.DB.CreateDataset(ctx, dataset, opts)
}

func (i *Index) GetDataset(ctx context.Context, datasetID string) (*types.Dataset, error) {
return i.DB.GetDataset(ctx, datasetID)
// GetDataset returns the dataset with the given ID by delegating to the
// index's DB, forwarding opts unchanged.
func (i *Index) GetDataset(ctx context.Context, datasetID string, opts *types.DatasetGetOpts) (*types.Dataset, error) {
return i.DB.GetDataset(ctx, datasetID, opts)
}

func (i *Index) ListDatasets(ctx context.Context) ([]types.Dataset, error) {
Expand Down
4 changes: 2 additions & 2 deletions knowledge/pkg/index/sqlite/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ func (i *Index) CreateDataset(ctx context.Context, dataset types.Dataset, opts *
return i.DB.CreateDataset(ctx, dataset, opts)
}

func (i *Index) GetDataset(ctx context.Context, datasetID string) (*types.Dataset, error) {
return i.DB.GetDataset(ctx, datasetID)
// GetDataset returns the dataset with the given ID by delegating to the
// index's DB, forwarding opts unchanged.
func (i *Index) GetDataset(ctx context.Context, datasetID string, opts *types.DatasetGetOpts) (*types.Dataset, error) {
return i.DB.GetDataset(ctx, datasetID, opts)
}

func (i *Index) ListDatasets(ctx context.Context) ([]types.Dataset, error) {
Expand Down
4 changes: 4 additions & 0 deletions knowledge/pkg/index/types/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ type DatasetCreateOpts struct {
ErrOnExists bool
}

// DatasetGetOpts holds optional settings for fetching a single dataset.
type DatasetGetOpts struct {
// IncludeFiles, when true, makes the DB-level GetDataset preload the
// dataset's "Files.Documents" association. When it is false, or when a nil
// *DatasetGetOpts is passed, that association is not preloaded.
IncludeFiles bool
}

// Dataset refers to a VectorDB data space.
// @Description Dataset refers to a VectorDB data space.
type Dataset struct {
Expand Down
16 changes: 14 additions & 2 deletions knowledge/pkg/index/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,21 @@ func (db *DB) DeleteDataset(ctx context.Context, datasetID string) error {
return nil
}

func (db *DB) GetDataset(ctx context.Context, datasetID string) (*Dataset, error) {
func (db *DB) GetDataset(ctx context.Context, datasetID string, opts *DatasetGetOpts) (*Dataset, error) {
dataset := &Dataset{}
tx := db.WithContext(ctx).Preload("Files.Documents").First(dataset, "id = ?", datasetID)
tx := db.WithContext(ctx)

if opts == nil {
opts = &DatasetGetOpts{
IncludeFiles: false,
}
}

if opts.IncludeFiles {
tx = tx.Preload("Files.Documents")
}

tx = tx.First(dataset, "id = ?", datasetID)
if tx.Error != nil {
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return nil, nil
Expand Down

0 comments on commit 62a74c4

Please sign in to comment.