diff --git a/.github/workflows/ci-check-repo.yaml b/.github/workflows/ci-check-repo.yaml index 770ecbba2e1..fa16ce85111 100644 --- a/.github/workflows/ci-check-repo.yaml +++ b/.github/workflows/ci-check-repo.yaml @@ -152,8 +152,6 @@ jobs: cwd: "." pull: "--ff" - name: Check generated protobufs - env: - USE_BAZEL_VERSION: 7.4.0 working-directory: ./proto env: USE_BAZEL_VERSION: 7.4.0 diff --git a/go/cmd/dolt/commands/engine/sqlengine.go b/go/cmd/dolt/commands/engine/sqlengine.go index d37f25e212b..19aee1af7ea 100644 --- a/go/cmd/dolt/commands/engine/sqlengine.go +++ b/go/cmd/dolt/commands/engine/sqlengine.go @@ -47,6 +47,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/sqle/statspro" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer" "github.com/dolthub/dolt/go/libraries/utils/config" + "github.com/dolthub/dolt/go/libraries/utils/filesys" ) // SqlEngine packages up the context necessary to run sql queries against dsqle. @@ -55,6 +56,7 @@ type SqlEngine struct { contextFactory contextFactory dsessFactory sessionFactory engine *gms.Engine + fs filesys.Filesys } type sessionFactory func(mysqlSess *sql.BaseSession, pro sql.DatabaseProvider) (*dsess.DoltSession, error) @@ -194,6 +196,7 @@ func NewSqlEngine( sqlEngine.contextFactory = sqlContextFactory() sqlEngine.dsessFactory = sessFactory sqlEngine.engine = engine + sqlEngine.fs = pro.FileSystem() // configuring stats depends on sessionBuilder // sessionBuilder needs ref to statsProv @@ -314,8 +317,15 @@ func (se *SqlEngine) GetUnderlyingEngine() *gms.Engine { return se.engine } +func (se *SqlEngine) FileSystem() filesys.Filesys { + return se.fs +} + func (se *SqlEngine) Close() error { if se.engine != nil { + if se.engine.Analyzer.Catalog.BinlogReplicaController != nil { + dblr.DoltBinlogReplicaController.Close() + } return se.engine.Close() } return nil diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index 9374608c37e..20e538c8e9a 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -559,21 +559,27 @@ func ConfigureServices( } listenaddr := fmt.Sprintf(":%d", port) + sqlContextInterceptor := sqle.SqlContextServerInterceptor{ + Factory: sqlEngine.NewDefaultContext, + } args := remotesrv.ServerArgs{ Logger: logrus.NewEntry(lgr), ReadOnly: apiReadOnly || serverConfig.ReadOnly(), HttpListenAddr: listenaddr, GrpcListenAddr: listenaddr, ConcurrencyControl: remotesapi.PushConcurrencyControl_PUSH_CONCURRENCY_CONTROL_ASSERT_WORKING_SET, + Options: sqlContextInterceptor.Options(), + HttpInterceptor: sqlContextInterceptor.HTTP(nil), } var err error - args.FS, args.DBCache, err = sqle.RemoteSrvFSAndDBCache(sqlEngine.NewDefaultContext, sqle.DoNotCreateUnknownDatabases) + args.FS = sqlEngine.FileSystem() + args.DBCache, err = sqle.RemoteSrvDBCache(sqle.GetInterceptorSqlContext, sqle.DoNotCreateUnknownDatabases) if err != nil { lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err) return err } - authenticator := newAccessController(sqlEngine.NewDefaultContext, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb) + authenticator := newAccessController(sqle.GetInterceptorSqlContext, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb) args = sqle.WithUserPasswordAuth(args, authenticator) args.TLSConfig = serverConf.TLSConfig @@ -621,6 +627,7 @@ func ConfigureServices( lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err) return err } + args.FS = sqlEngine.FileSystem() 
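With the SqlContextServerInterceptor wired into the remotesapi server above, request handlers no longer build their own SQL session; they recover the request-scoped *sql.Context that the interceptor installed. A minimal sketch of that retrieval, assuming a hypothetical service and request/response types (only sqle.GetInterceptorSqlContext is from this change):

package example

import (
	"context"

	"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
)

// hypothetical request/response types, for illustration only
type pingRequest struct{}
type pingResponse struct{ db string }

type exampleService struct{}

// Ping shows the retrieval pattern: the interceptor has already set up the session
// for this request, so the handler just pulls the sql.Context out of ctx.
func (exampleService) Ping(ctx context.Context, _ *pingRequest) (*pingResponse, error) {
	sqlCtx, err := sqle.GetInterceptorSqlContext(ctx)
	if err != nil {
		// Only happens if the server was built without the interceptor's Options().
		return nil, err
	}
	return &pingResponse{db: sqlCtx.GetCurrentDatabase()}, nil
}

The same retrieval is what the remotesrvStore DBCache and the cluster gRPC services rely on after this change.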
clusterRemoteSrvTLSConfig, err := LoadClusterTLSConfig(serverConfig.ClusterConfig()) if err != nil { @@ -634,7 +641,7 @@ func ConfigureServices( lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err) return err } - clusterController.RegisterGrpcServices(sqlEngine.NewDefaultContext, clusterRemoteSrv.srv.GrpcServer()) + clusterController.RegisterGrpcServices(sqle.GetInterceptorSqlContext, clusterRemoteSrv.srv.GrpcServer()) clusterRemoteSrv.lis, err = clusterRemoteSrv.srv.Listeners() if err != nil { diff --git a/go/cmd/dolt/doltversion/version.go b/go/cmd/dolt/doltversion/version.go index 8e9fa83276f..e6561764635 100644 --- a/go/cmd/dolt/doltversion/version.go +++ b/go/cmd/dolt/doltversion/version.go @@ -16,5 +16,5 @@ package doltversion const ( - Version = "1.47.1" + Version = "1.48.0" ) diff --git a/go/go.mod b/go/go.mod index ca0a435c3e7..2fc96289a25 100644 --- a/go/go.mod +++ b/go/go.mod @@ -56,7 +56,7 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 github.com/creasty/defaults v1.6.0 github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 - github.com/dolthub/go-mysql-server v0.19.1-0.20250123004221-f5a5bcea7eed + github.com/dolthub/go-mysql-server v0.19.1-0.20250131110511-67aa2a430366 github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 github.com/dolthub/swiss v0.1.0 github.com/esote/minmaxheap v1.0.0 diff --git a/go/go.sum b/go/go.sum index e18888690e4..0af441a1c7c 100644 --- a/go/go.sum +++ b/go/go.sum @@ -179,8 +179,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90 h1:Sni8jrP0sy/w9ZYXoff4g/ixe+7bFCZlfCqXKJSU+zM= github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= -github.com/dolthub/go-mysql-server v0.19.1-0.20250123004221-f5a5bcea7eed h1:2EQHWtMkjyN/SNfbg/nh/a0RANq8V8gxNynYum2Kq+s= -github.com/dolthub/go-mysql-server v0.19.1-0.20250123004221-f5a5bcea7eed/go.mod h1:jYEJ8tNkA7K3k39X8iMqaX3MSMmViRgh222JSLHDgVc= +github.com/dolthub/go-mysql-server v0.19.1-0.20250131110511-67aa2a430366 h1:pJ+upgX6hrhyqgpkmk9Ye9lIPSualMHZcUMs8kWknV4= +github.com/dolthub/go-mysql-server v0.19.1-0.20250131110511-67aa2a430366/go.mod h1:jYEJ8tNkA7K3k39X8iMqaX3MSMmViRgh222JSLHDgVc= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= diff --git a/go/libraries/doltcore/doltdb/commit_hooks_test.go b/go/libraries/doltcore/doltdb/commit_hooks_test.go index fa20ac9c52e..fae89ab9824 100644 --- a/go/libraries/doltcore/doltdb/commit_hooks_test.go +++ b/go/libraries/doltcore/doltdb/commit_hooks_test.go @@ -43,7 +43,7 @@ func TestPushOnWriteHook(t *testing.T) { ctx := context.Background() // destination repo - testDir, err := test.ChangeToTestDir("TestReplicationDest") + testDir, err := test.ChangeToTestDir(t.TempDir(), "TestReplicationDest") if err != nil { panic("Couldn't change the working directory to the test directory.") @@ -62,7 +62,7 @@ func TestPushOnWriteHook(t *testing.T) { destDB, _ := LoadDoltDB(context.Background(), types.Format_Default, LocalDirDoltDB, filesys.LocalFS) // source repo - testDir, err = 
test.ChangeToTestDir("TestReplicationSource") + testDir, err = test.ChangeToTestDir(t.TempDir(), "TestReplicationSource") if err != nil { panic("Couldn't change the working directory to the test directory.") @@ -183,7 +183,7 @@ func TestAsyncPushOnWrite(t *testing.T) { ctx := context.Background() // destination repo - testDir, err := test.ChangeToTestDir("TestReplicationDest") + testDir, err := test.ChangeToTestDir(t.TempDir(), "TestReplicationDest") if err != nil { panic("Couldn't change the working directory to the test directory.") @@ -202,7 +202,7 @@ func TestAsyncPushOnWrite(t *testing.T) { destDB, _ := LoadDoltDB(context.Background(), types.Format_Default, LocalDirDoltDB, filesys.LocalFS) // source repo - testDir, err = test.ChangeToTestDir("TestReplicationSource") + testDir, err = test.ChangeToTestDir(t.TempDir(), "TestReplicationSource") if err != nil { panic("Couldn't change the working directory to the test directory.") diff --git a/go/libraries/doltcore/doltdb/doltdb_test.go b/go/libraries/doltcore/doltdb/doltdb_test.go index ef1f09de50e..c04abe1be23 100644 --- a/go/libraries/doltcore/doltdb/doltdb_test.go +++ b/go/libraries/doltcore/doltdb/doltdb_test.go @@ -219,7 +219,7 @@ func TestEmptyInMemoryRepoCreation(t *testing.T) { } func TestLoadNonExistentLocalFSRepo(t *testing.T) { - _, err := test.ChangeToTestDir("TestLoadRepo") + _, err := test.ChangeToTestDir(t.TempDir(), "TestLoadRepo") if err != nil { panic("Couldn't change the working directory to the test directory.") @@ -231,7 +231,7 @@ func TestLoadNonExistentLocalFSRepo(t *testing.T) { } func TestLoadBadLocalFSRepo(t *testing.T) { - testDir, err := test.ChangeToTestDir("TestLoadRepo") + testDir, err := test.ChangeToTestDir(t.TempDir(), "TestLoadRepo") if err != nil { panic("Couldn't change the working directory to the test directory.") @@ -246,7 +246,7 @@ func TestLoadBadLocalFSRepo(t *testing.T) { } func TestLDNoms(t *testing.T) { - testDir, err := test.ChangeToTestDir("TestLoadRepo") + testDir, err := test.ChangeToTestDir(t.TempDir(), "TestLoadRepo") if err != nil { panic("Couldn't change the working directory to the test directory.") diff --git a/go/libraries/doltcore/dtestutils/environment.go b/go/libraries/doltcore/dtestutils/environment.go index c2d1eddf86b..9555afab831 100644 --- a/go/libraries/doltcore/dtestutils/environment.go +++ b/go/libraries/doltcore/dtestutils/environment.go @@ -16,7 +16,6 @@ package dtestutils import ( "context" - "os" "path/filepath" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" @@ -41,12 +40,7 @@ func CreateTestEnv() *env.DoltEnv { // CreateTestEnvForLocalFilesystem creates a new DoltEnv for testing, using a local FS, instead of an in-memory // filesystem, for persisting files. This is useful for tests that require a disk-based filesystem and will not // work correctly with an in-memory filesystem and in-memory blob store (e.g. dolt_undrop() tests). 
-func CreateTestEnvForLocalFilesystem() *env.DoltEnv { - tempDir, err := os.MkdirTemp(os.TempDir(), "dolt-*") - if err != nil { - panic(err) - } - +func CreateTestEnvForLocalFilesystem(tempDir string) *env.DoltEnv { fs, err := filesys.LocalFilesysWithWorkingDir(tempDir) if err != nil { panic(err) diff --git a/go/libraries/doltcore/dtestutils/sql_server_driver/cmd.go b/go/libraries/doltcore/dtestutils/sql_server_driver/cmd.go index 802abf595ca..75c6cd936ad 100644 --- a/go/libraries/doltcore/dtestutils/sql_server_driver/cmd.go +++ b/go/libraries/doltcore/dtestutils/sql_server_driver/cmd.go @@ -76,13 +76,9 @@ type DoltUser struct { var _ DoltCmdable = DoltUser{} var _ DoltDebuggable = DoltUser{} -func NewDoltUser() (DoltUser, error) { - tmpdir, err := os.MkdirTemp("", "go-sql-server-driver-") - if err != nil { - return DoltUser{}, err - } +func NewDoltUser(tmpdir string) (DoltUser, error) { res := DoltUser{tmpdir} - err = res.DoltExec("config", "--global", "--add", "metrics.disabled", "true") + err := res.DoltExec("config", "--global", "--add", "metrics.disabled", "true") if err != nil { return DoltUser{}, err } diff --git a/go/libraries/doltcore/env/actions/tag.go b/go/libraries/doltcore/env/actions/tag.go index 1580ad43808..a6bb57e8c5e 100644 --- a/go/libraries/doltcore/env/actions/tag.go +++ b/go/libraries/doltcore/env/actions/tag.go @@ -25,6 +25,8 @@ import ( "github.com/dolthub/dolt/go/store/datas" ) +const DefaultPageSize = 100 + type TagProps struct { TaggerName string TaggerEmail string @@ -97,6 +99,30 @@ func DeleteTagsOnDB(ctx context.Context, ddb *doltdb.DoltDB, tagNames ...string) return nil } +// IterUnresolvedTags iterates over tags in dEnv.DoltDB, and calls cb() for each with an unresolved Tag. +func IterUnresolvedTags(ctx context.Context, ddb *doltdb.DoltDB, cb func(tag *doltdb.TagResolver) (stop bool, err error)) error { + tagRefs, err := ddb.GetTags(ctx) + if err != nil { + return err + } + + tagResolvers, err := ddb.GetTagResolvers(ctx, tagRefs) + if err != nil { + return err + } + + for _, tagResolver := range tagResolvers { + stop, err := cb(&tagResolver) + if err != nil { + return err + } + if stop { + break + } + } + return nil +} + // IterResolvedTags iterates over tags in dEnv.DoltDB from newest to oldest, resolving the tag to a commit and calling cb(). func IterResolvedTags(ctx context.Context, ddb *doltdb.DoltDB, cb func(tag *doltdb.Tag) (stop bool, err error)) error { tagRefs, err := ddb.GetTags(ctx) @@ -138,26 +164,81 @@ func IterResolvedTags(ctx context.Context, ddb *doltdb.DoltDB, cb func(tag *dolt return nil } -// IterUnresolvedTags iterates over tags in dEnv.DoltDB, and calls cb() for each with an unresovled Tag. -func IterUnresolvedTags(ctx context.Context, ddb *doltdb.DoltDB, cb func(tag *doltdb.TagResolver) (stop bool, err error)) error { +// IterResolvedTagsPaginated iterates over tags in dEnv.DoltDB in their default lexicographical order, resolving the tag to a commit and calling cb(). +// Returns the next tag name if there are more results available. 
+func IterResolvedTagsPaginated(ctx context.Context, ddb *doltdb.DoltDB, startTag string, cb func(tag *doltdb.Tag) (stop bool, err error)) (string, error) { + // tags returned here are sorted lexicographically tagRefs, err := ddb.GetTags(ctx) if err != nil { - return err + return "", err } - tagResolvers, err := ddb.GetTagResolvers(ctx, tagRefs) - if err != nil { - return err + // find starting index based on start tag + startIdx := 0 + if startTag != "" { + for i, tr := range tagRefs { + if tr.GetPath() == startTag { + startIdx = i + 1 // start after the given tag + break + } + } } - for _, tagResolver := range tagResolvers { - stop, err := cb(&tagResolver) + // get page of results + endIdx := startIdx + DefaultPageSize + if endIdx > len(tagRefs) { + endIdx = len(tagRefs) + } + + pageTagRefs := tagRefs[startIdx:endIdx] + + // resolve tags for this page + for _, tr := range pageTagRefs { + tag, err := ddb.ResolveTag(ctx, tr.(ref.TagRef)) if err != nil { - return err + return "", err } + + stop, err := cb(tag) + if err != nil { + return "", err + } + if stop { break } } - return nil + + // return next tag name if there are more results + if endIdx < len(tagRefs) { + lastTag := pageTagRefs[len(pageTagRefs)-1] + return lastTag.GetPath(), nil + } + + return "", nil +} + +// VisitResolvedTag iterates over tags in ddb until the given tag name is found, then calls cb() with the resolved tag. +func VisitResolvedTag(ctx context.Context, ddb *doltdb.DoltDB, tagName string, cb func(tag *doltdb.Tag) error) error { + tagRefs, err := ddb.GetTags(ctx) + if err != nil { + return err + } + + for _, r := range tagRefs { + tr, ok := r.(ref.TagRef) + if !ok { + return fmt.Errorf("DoltDB.GetTags() returned non-tag DoltRef") + } + + if tr.GetPath() == tagName { + tag, err := ddb.ResolveTag(ctx, tr) + if err != nil { + return err + } + return cb(tag) + } + } + + return doltdb.ErrTagNotFound } diff --git a/go/libraries/doltcore/env/actions/tag_test.go b/go/libraries/doltcore/env/actions/tag_test.go new file mode 100644 index 00000000000..6ad86189abb --- /dev/null +++ b/go/libraries/doltcore/env/actions/tag_test.go @@ -0,0 +1,150 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
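IterResolvedTagsPaginated above returns the path of the last tag in the page when more results remain, and an empty string once the listing is exhausted. A sketch of how a caller might walk every page with that token (the printAllTags helper is hypothetical; the iteration contract is as defined above):

package example

import (
	"context"
	"fmt"

	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
	"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
)

// printAllTags pages through all tags, DefaultPageSize at a time, feeding the
// token returned by each call back in as the next startTag.
func printAllTags(ctx context.Context, ddb *doltdb.DoltDB) error {
	startTag := ""
	for {
		next, err := actions.IterResolvedTagsPaginated(ctx, ddb, startTag, func(tag *doltdb.Tag) (bool, error) {
			fmt.Println(tag.Name)
			return false, nil // keep iterating within the page
		})
		if err != nil {
			return err
		}
		if next == "" {
			return nil // no more pages
		}
		startTag = next
	}
}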
+ +package actions + +import ( + "context" + "fmt" + "sort" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/env" + "github.com/dolthub/dolt/go/libraries/utils/filesys" + "github.com/dolthub/dolt/go/store/types" +) + +const ( + testHomeDir = "/user/bheni" + workingDir = "/user/bheni/datasets/addresses" + credsDir = "creds" + + configFile = "config.json" + GlobalConfigFile = "config_global.json" +) + +func testHomeDirFunc() (string, error) { + return testHomeDir, nil +} + +func createTestEnv() (*env.DoltEnv, *filesys.InMemFS) { + initialDirs := []string{testHomeDir, workingDir} + initialFiles := map[string][]byte{} + + fs := filesys.NewInMemFS(initialDirs, initialFiles, workingDir) + dEnv := env.Load(context.Background(), testHomeDirFunc, fs, doltdb.InMemDoltDB, "test") + + return dEnv, fs +} + +func TestVisitResolvedTag(t *testing.T) { + dEnv, _ := createTestEnv() + ctx := context.Background() + + // Initialize repo + err := dEnv.InitRepo(ctx, types.Format_Default, "test user", "test@test.com", "main") + require.NoError(t, err) + + // Create a tag + tagName := "test-tag" + tagMsg := "test tag message" + err = CreateTag(ctx, dEnv, tagName, "main", TagProps{TaggerName: "test user", TaggerEmail: "test@test.com", Description: tagMsg}) + require.NoError(t, err) + + // Visit the tag and verify its properties + var foundTag *doltdb.Tag + err = VisitResolvedTag(ctx, dEnv.DoltDB, tagName, func(tag *doltdb.Tag) error { + foundTag = tag + return nil + }) + require.NoError(t, err) + require.NotNil(t, foundTag) + require.Equal(t, tagName, foundTag.Name) + require.Equal(t, tagMsg, foundTag.Meta.Description) + + // Test visiting non-existent tag + err = VisitResolvedTag(ctx, dEnv.DoltDB, "non-existent-tag", func(tag *doltdb.Tag) error { + return nil + }) + require.Equal(t, doltdb.ErrTagNotFound, err) +} + +func TestIterResolvedTagsPaginated(t *testing.T) { + dEnv, _ := createTestEnv() + ctx := context.Background() + + // Initialize repo + err := dEnv.InitRepo(ctx, types.Format_Default, "test user", "test@test.com", "main") + require.NoError(t, err) + + expectedTagNames := make([]string, DefaultPageSize*2) + // Create multiple tags with different timestamps + tagNames := make([]string, DefaultPageSize*2) + for i := range tagNames { + tagName := fmt.Sprintf("tag-%d", i) + err = CreateTag(ctx, dEnv, tagName, "main", TagProps{ + TaggerName: "test user", + TaggerEmail: "test@test.com", + Description: fmt.Sprintf("test tag %s", tagName), + }) + time.Sleep(2 * time.Millisecond) + require.NoError(t, err) + tagNames[i] = tagName + expectedTagNames[i] = tagName + } + + // Sort expected tag names to ensure they are in the correct order + sort.Strings(expectedTagNames) + + // Test first page + var foundTags []string + pageToken, err := IterResolvedTagsPaginated(ctx, dEnv.DoltDB, "", func(tag *doltdb.Tag) (bool, error) { + foundTags = append(foundTags, tag.Name) + return false, nil + }) + require.NoError(t, err) + require.NotEmpty(t, pageToken) // Should have next page + require.Equal(t, DefaultPageSize, len(foundTags)) // Default page size tags returned + require.Equal(t, expectedTagNames[:DefaultPageSize], foundTags) + + // Test second page + var secondPageTags []string + nextPageToken, err := IterResolvedTagsPaginated(ctx, dEnv.DoltDB, pageToken, func(tag *doltdb.Tag) (bool, error) { + secondPageTags = append(secondPageTags, tag.Name) + return false, nil + }) + + require.NoError(t, err) + 
require.Empty(t, nextPageToken) // Should be no more pages + require.Equal(t, DefaultPageSize, len(secondPageTags)) // Remaining tags + require.Equal(t, expectedTagNames[DefaultPageSize:], secondPageTags) + + // Verify all tags were found + allFoundTags := append(foundTags, secondPageTags...) + require.Equal(t, len(tagNames), len(allFoundTags)) + require.Equal(t, expectedTagNames, allFoundTags) + + // Test early termination + var earlyTermTags []string + _, err = IterResolvedTagsPaginated(ctx, dEnv.DoltDB, "", func(tag *doltdb.Tag) (bool, error) { + earlyTermTags = append(earlyTermTags, tag.Name) + return true, nil // Stop after first tag + }) + require.NoError(t, err) + require.Equal(t, 1, len(earlyTermTags)) +} diff --git a/go/libraries/doltcore/env/multi_repo_env_test.go b/go/libraries/doltcore/env/multi_repo_env_test.go index 9a5b6bdb82b..d97baeacd8b 100644 --- a/go/libraries/doltcore/env/multi_repo_env_test.go +++ b/go/libraries/doltcore/env/multi_repo_env_test.go @@ -115,7 +115,7 @@ func initRepoWithRelativePath(t *testing.T, envPath string, hdp HomeDirProvider) } func TestMultiEnvForDirectory(t *testing.T) { - rootPath, err := test.ChangeToTestDir("TestDoltEnvAsMultiEnv") + rootPath, err := test.ChangeToTestDir(t.TempDir(), "TestDoltEnvAsMultiEnv") require.NoError(t, err) hdp := func() (string, error) { return rootPath, nil } @@ -150,7 +150,7 @@ func TestMultiEnvForDirectory(t *testing.T) { } func TestMultiEnvForDirectoryWithMultipleRepos(t *testing.T) { - rootPath, err := test.ChangeToTestDir("TestDoltEnvAsMultiEnvWithMultipleRepos") + rootPath, err := test.ChangeToTestDir(t.TempDir(), "TestDoltEnvAsMultiEnvWithMultipleRepos") require.NoError(t, err) hdp := func() (string, error) { return rootPath, nil } @@ -177,7 +177,7 @@ func TestMultiEnvForDirectoryWithMultipleRepos(t *testing.T) { } func initMultiEnv(t *testing.T, testName string, names []string) (string, HomeDirProvider, map[string]*DoltEnv) { - rootPath, err := test.ChangeToTestDir(testName) + rootPath, err := test.ChangeToTestDir(t.TempDir(), testName) require.NoError(t, err) hdp := func() (string, error) { return rootPath, nil } diff --git a/go/libraries/doltcore/sqle/binlogreplication/binlog_metadata_persistence.go b/go/libraries/doltcore/sqle/binlogreplication/binlog_metadata_persistence.go index 01c6b38d0fa..3737498ff1f 100644 --- a/go/libraries/doltcore/sqle/binlogreplication/binlog_metadata_persistence.go +++ b/go/libraries/doltcore/sqle/binlogreplication/binlog_metadata_persistence.go @@ -104,7 +104,9 @@ func persistReplicaRunningState(ctx *sql.Context, state replicaRunningState) err // loadReplicationConfiguration loads the replication configuration for default channel ("") from // the "mysql" database, |mysqlDb|. 
-func loadReplicationConfiguration(_ *sql.Context, mysqlDb *mysql_db.MySQLDb) (*mysql_db.ReplicaSourceInfo, error) { +func loadReplicationConfiguration(ctx *sql.Context, mysqlDb *mysql_db.MySQLDb) (*mysql_db.ReplicaSourceInfo, error) { + sql.SessionCommandBegin(ctx.Session) + defer sql.SessionCommandEnd(ctx.Session) rd := mysqlDb.Reader() defer rd.Close() diff --git a/go/libraries/doltcore/sqle/binlogreplication/binlog_primary_test.go b/go/libraries/doltcore/sqle/binlogreplication/binlog_primary_test.go index 02cb6a0a77e..9b8770befa1 100644 --- a/go/libraries/doltcore/sqle/binlogreplication/binlog_primary_test.go +++ b/go/libraries/doltcore/sqle/binlogreplication/binlog_primary_test.go @@ -430,7 +430,7 @@ func TestBinlogPrimary_ReplicaRestart(t *testing.T) { // Restart the MySQL replica and reconnect to the Dolt primary prevPrimaryDatabase := primaryDatabase var err error - mySqlPort, mySqlProcess, err = startMySqlServer(testDir) + mySqlPort, mySqlProcess, err = startMySqlServer(t, testDir) require.NoError(t, err) replicaDatabase = primaryDatabase primaryDatabase = prevPrimaryDatabase @@ -1042,7 +1042,7 @@ func outputReplicaApplierStatus(t *testing.T) { newRows, err := replicaDatabase.Queryx("select * from performance_schema.replication_applier_status_by_worker;") require.NoError(t, err) allNewRows := readAllRowsIntoMaps(t, newRows) - fmt.Printf("\n\nreplication_applier_status_by_worker: %v\n", allNewRows) + t.Logf("\n\nreplication_applier_status_by_worker: %v\n", allNewRows) } // outputShowReplicaStatus prints out replica status information. This is useful for debugging @@ -1052,7 +1052,7 @@ func outputShowReplicaStatus(t *testing.T) { newRows, err := replicaDatabase.Queryx("show replica status;") require.NoError(t, err) allNewRows := readAllRowsIntoMaps(t, newRows) - fmt.Printf("\n\nSHOW REPLICA STATUS: %v\n", allNewRows) + t.Logf("\n\nSHOW REPLICA STATUS: %v\n", allNewRows) } // copyMap returns a copy of the specified map |m|. 
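loadReplicationConfiguration above, like the other replication code paths touched in this change, now brackets its use of a borrowed session with sql.SessionCommandBegin / sql.SessionCommandEnd. A minimal sketch of that bracketing for any helper doing work against a session outside the normal query path (the helper name is hypothetical; the calls mirror the usage in this diff):

package example

import (
	"github.com/dolthub/go-mysql-server/sql"
)

// withSessionCommand brackets background work on a session so the engine sees a
// well-defined command lifecycle, mirroring how this change wraps replication work.
func withSessionCommand(ctx *sql.Context, work func(*sql.Context) error) error {
	sql.SessionCommandBegin(ctx.Session)
	defer sql.SessionCommandEnd(ctx.Session)
	return work(ctx)
}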
@@ -1098,7 +1098,7 @@ func waitForReplicaToReconnect(t *testing.T) { func mustRestartDoltPrimaryServer(t *testing.T) { var err error prevReplicaDatabase := replicaDatabase - doltPort, doltProcess, err = startDoltSqlServer(testDir, nil) + doltPort, doltProcess, err = startDoltSqlServer(t, testDir, nil) require.NoError(t, err) primaryDatabase = replicaDatabase replicaDatabase = prevReplicaDatabase @@ -1109,7 +1109,7 @@ func mustRestartMySqlReplicaServer(t *testing.T) { var err error prevPrimaryDatabase := primaryDatabase - mySqlPort, mySqlProcess, err = startMySqlServer(testDir) + mySqlPort, mySqlProcess, err = startMySqlServer(t, testDir) require.NoError(t, err) replicaDatabase = primaryDatabase primaryDatabase = prevPrimaryDatabase diff --git a/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_applier.go b/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_applier.go index 3833aecc6aa..36bce2c5767 100644 --- a/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_applier.go +++ b/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_applier.go @@ -19,6 +19,7 @@ import ( "io" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -63,6 +64,7 @@ type binlogReplicaApplier struct { currentPosition *mysql.Position // successfully executed GTIDs filters *filterConfiguration running atomic.Bool + handlerWg sync.WaitGroup engine *gms.Engine dbsWithUncommittedChanges map[string]struct{} } @@ -88,10 +90,14 @@ const rowFlag_rowsAreComplete = 0x0008 // Go spawns a new goroutine to run the applier's binlog event handler. func (a *binlogReplicaApplier) Go(ctx *sql.Context) { + if !a.running.CompareAndSwap(false, true) { + panic("attempt to start binlogReplicaApplier while it is already running") + } + a.handlerWg.Add(1) go func() { - a.running.Store(true) + defer a.handlerWg.Done() + defer a.running.Store(false) err := a.replicaBinlogEventHandler(ctx) - a.running.Store(false) if err != nil { ctx.GetLogger().Errorf("unexpected error of type %T: '%v'", err, err.Error()) DoltBinlogReplicaController.setSqlError(mysql.ERUnknownError, err.Error()) @@ -104,6 +110,27 @@ func (a *binlogReplicaApplier) IsRunning() bool { return a.running.Load() } +// Stop will shut down the replication thread if it is running. This is not safe to call concurrently with |Go|. +// This is used by the controller when implementing STOP REPLICA, but it is also used on shutdown when the +// replication thread should be shut down cleanly in the event that it is still running. +func (a *binlogReplicaApplier) Stop() { + if a.IsRunning() { + // We jump through some hoops here. It is not the case that the replication thread + // is guaranteed to read from |stopReplicationChan|. Instead, it can exit on its + // own with an error, for example, after exceeding connection retry attempts. + done := make(chan struct{}) + go func() { + defer close(done) + a.handlerWg.Wait() + }() + select { + case a.stopReplicationChan <- struct{}{}: + case <-done: + } + a.handlerWg.Wait() + } +} + // connectAndStartReplicationEventStream connects to the configured MySQL replication source, including pausing // and retrying if errors are encountered.
func (a *binlogReplicaApplier) connectAndStartReplicationEventStream(ctx *sql.Context) (*mysql.Conn, error) { @@ -263,25 +290,21 @@ func (a *binlogReplicaApplier) startReplicationEventStream(ctx *sql.Context, con func (a *binlogReplicaApplier) replicaBinlogEventHandler(ctx *sql.Context) error { engine := a.engine - var conn *mysql.Conn var eventProducer *binlogEventProducer // Process binlog events for { - if conn == nil { + if eventProducer == nil { ctx.GetLogger().Debug("no binlog connection to source, attempting to establish one") - if eventProducer != nil { - eventProducer.Stop() - } - var err error - if conn, err = a.connectAndStartReplicationEventStream(ctx); err == ErrReplicationStopped { + if conn, err := a.connectAndStartReplicationEventStream(ctx); err == ErrReplicationStopped { return nil } else if err != nil { return err + } else { + eventProducer = newBinlogEventProducer(conn) + eventProducer.Go(ctx) } - eventProducer = newBinlogEventProducer(conn) - eventProducer.Go(ctx) } select { @@ -305,8 +328,6 @@ func (a *binlogReplicaApplier) replicaBinlogEventHandler(ctx *sql.Context) error }) eventProducer.Stop() eventProducer = nil - conn.Close() - conn = nil } } else { // otherwise, log the error if it's something we don't expect and continue @@ -317,6 +338,7 @@ func (a *binlogReplicaApplier) replicaBinlogEventHandler(ctx *sql.Context) error case <-a.stopReplicationChan: ctx.GetLogger().Trace("received stop replication signal") eventProducer.Stop() + eventProducer = nil return nil } } @@ -325,6 +347,8 @@ func (a *binlogReplicaApplier) replicaBinlogEventHandler(ctx *sql.Context) error // processBinlogEvent processes a single binlog event message and returns an error if there were any problems // processing it. func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.Engine, event mysql.BinlogEvent) error { + sql.SessionCommandBegin(ctx.Session) + defer sql.SessionCommandEnd(ctx.Session) var err error createCommit := false diff --git a/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_controller.go b/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_controller.go index 8e22a09cd92..5ea8edd9681 100644 --- a/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_controller.go +++ b/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_controller.go @@ -157,7 +157,9 @@ func (d *doltBinlogReplicaController) StartReplica(ctx *sql.Context) error { // changes and execute DDL statements on the running server. If the account doesn't exist, it will be // created and locked to disable log ins, and if it does exist, but is missing super privs or is not // locked, it will be given superuser privs and locked. -func (d *doltBinlogReplicaController) configureReplicationUser(_ *sql.Context) { +func (d *doltBinlogReplicaController) configureReplicationUser(ctx *sql.Context) { + sql.SessionCommandBegin(ctx.Session) + defer sql.SessionCommandEnd(ctx.Session) mySQLDb := d.engine.Analyzer.Catalog.MySQLDb ed := mySQLDb.Editor() defer ed.Close() @@ -180,12 +182,15 @@ func (d *doltBinlogReplicaController) SetEngine(engine *sqle.Engine) { // StopReplica implements the BinlogReplicaController interface. 
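The new applier Stop above cannot simply send on stopReplicationChan, because the handler goroutine may have already exited on its own; StopReplica below relies on that behavior. A generic sketch of the signal-or-join pattern it uses (names hypothetical):

package example

import "sync"

// stopOrJoin asks a worker to stop without blocking forever when the worker has
// already exited on its own (e.g. after exhausting its retries), then waits for it.
// This mirrors the select over stopCh and a done channel used by the applier above.
func stopOrJoin(stopCh chan<- struct{}, workers *sync.WaitGroup) {
	done := make(chan struct{})
	go func() {
		defer close(done)
		workers.Wait()
	}()
	select {
	case stopCh <- struct{}{}:
	case <-done:
	}
	workers.Wait()
}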
func (d *doltBinlogReplicaController) StopReplica(ctx *sql.Context) error { + d.operationMutex.Lock() + defer d.operationMutex.Unlock() + if d.applier.IsRunning() == false { ctx.Warn(3084, "Replication thread(s) for channel '' are already stopped.") return nil } - d.applier.stopReplicationChan <- struct{}{} + d.applier.Stop() d.updateStatus(func(status *binlogreplication.ReplicaStatus) { status.ReplicaIoRunning = binlogreplication.ReplicaIoNotRunning @@ -428,6 +433,17 @@ func (d *doltBinlogReplicaController) AutoStart(_ context.Context) error { return d.StartReplica(d.ctx) } +// Release all resources, such as replication threads, associated with the replication. +// This can only be done once in the lifecycle of the instance. Because DoltBinlogReplicaController +// is currently a global singleton, this should only be done once in the lifecycle of the +// application. +func (d *doltBinlogReplicaController) Close() { + d.applier.Stop() + if d.ctx != nil { + sql.SessionEnd(d.ctx.Session) + } +} + // // Helper functions // diff --git a/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_event_producer.go b/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_event_producer.go index 34f45eb3abf..872a7ccba31 100644 --- a/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_event_producer.go +++ b/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_event_producer.go @@ -15,6 +15,7 @@ package binlogreplication import ( + "sync" "sync/atomic" "github.com/dolthub/go-mysql-server/sql" @@ -30,19 +31,24 @@ type binlogEventProducer struct { conn *mysql.Conn errorChan chan error eventChan chan mysql.BinlogEvent + closeChan chan struct{} + wg sync.WaitGroup running atomic.Bool } // newBinlogEventProducer creates a new binlog event producer that reads from the specified, established MySQL // connection |conn|. The returned binlogEventProducer owns the communication channels // and is responsible for closing them when the binlogEventProducer is stopped. +// +// The BinlogEventProducer will take ownership of the supplied |*Conn| instance and +// will |Close| it when the producer itself exits. func newBinlogEventProducer(conn *mysql.Conn) *binlogEventProducer { producer := &binlogEventProducer{ conn: conn, eventChan: make(chan mysql.BinlogEvent), errorChan: make(chan error), + closeChan: make(chan struct{}), } - producer.running.Store(true) return producer } @@ -61,7 +67,14 @@ func (p *binlogEventProducer) ErrorChan() <-chan error { // Go starts this binlogEventProducer in a new goroutine. Right before this routine exits, it will close the // two communication channels it owns. 
func (p *binlogEventProducer) Go(_ *sql.Context) { + if !p.running.CompareAndSwap(false, true) { + panic("attempt to start binlogEventProducer more than once.") + } + p.wg.Add(1) go func() { + defer p.wg.Done() + defer close(p.errorChan) + defer close(p.eventChan) for p.IsRunning() { // ReadBinlogEvent blocks until a binlog event can be read and returned, so this has to be done on a // separate thread, otherwise the applier would be blocked and wouldn't be able to handle the STOP @@ -75,13 +88,19 @@ func (p *binlogEventProducer) Go(_ *sql.Context) { } if err != nil { - p.errorChan <- err + select { + case p.errorChan <- err: + case <-p.closeChan: + return + } } else { - p.eventChan <- event + select { + case p.eventChan <- event: + case <-p.closeChan: + return + } } } - close(p.errorChan) - close(p.eventChan) }() } @@ -92,5 +111,9 @@ func (p *binlogEventProducer) IsRunning() bool { // Stop requests for this binlogEventProducer to stop processing events as soon as possible. func (p *binlogEventProducer) Stop() { - p.running.Store(false) + if p.running.CompareAndSwap(true, false) { + p.conn.Close() + close(p.closeChan) + } + p.wg.Wait() } diff --git a/go/libraries/doltcore/sqle/binlogreplication/binlog_replication_restart_test.go b/go/libraries/doltcore/sqle/binlogreplication/binlog_replication_restart_test.go index 70810019f7c..1ebb0a046ca 100644 --- a/go/libraries/doltcore/sqle/binlogreplication/binlog_replication_restart_test.go +++ b/go/libraries/doltcore/sqle/binlogreplication/binlog_replication_restart_test.go @@ -49,7 +49,7 @@ func TestBinlogReplicationServerRestart(t *testing.T) { time.Sleep(1000 * time.Millisecond) var err error - doltPort, doltProcess, err = startDoltSqlServer(testDir, nil) + doltPort, doltProcess, err = startDoltSqlServer(t, testDir, nil) require.NoError(t, err) // Check replication status on the replica and assert configuration persisted diff --git a/go/libraries/doltcore/sqle/binlogreplication/binlog_replication_test.go b/go/libraries/doltcore/sqle/binlogreplication/binlog_replication_test.go index 9231185dcdc..55eae2086ed 100644 --- a/go/libraries/doltcore/sqle/binlogreplication/binlog_replication_test.go +++ b/go/libraries/doltcore/sqle/binlogreplication/binlog_replication_test.go @@ -29,6 +29,7 @@ import ( "slices" "strconv" "strings" + "sync" "syscall" "testing" "time" @@ -47,7 +48,6 @@ var mySqlProcess, doltProcess *os.Process var doltLogFilePath, oldDoltLogFilePath, mysqlLogFilePath string var doltLogFile, mysqlLogFile *os.File var testDir string -var originalWorkingDir string // doltReplicaSystemVars are the common system variables that need // to be set on a Dolt replica before replication is turned on. 
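The producer changes above replace bare channel sends with a select against closeChan, so a Stop() that never drains eventChan or errorChan cannot strand the goroutine. A small generic sketch of that send-or-abort idiom (function name hypothetical):

package example

// sendOrAbort delivers v on out unless closed has been closed first, returning
// whether the send happened. binlogEventProducer uses this shape for both its
// event and error channels after Stop() closes closeChan.
func sendOrAbort[T any](out chan<- T, v T, closed <-chan struct{}) bool {
	select {
	case out <- v:
		return true
	case <-closed:
		return false
	}
}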
@@ -55,6 +55,48 @@ var doltReplicaSystemVars = map[string]string{ "server_id": "42", } +func TestMain(m *testing.M) { + res := func() int { + defer func() { + cachedDoltDevBuildPathOnce.Do(func() {}) + if cachedDoltDevBuildPath != "" { + os.RemoveAll(filepath.Dir(cachedDoltDevBuildPath)) + } + }() + return m.Run() + }() + os.Exit(res) +} + +var cachedDoltDevBuildPath string +var cachedDoltDevBuildPathOnce sync.Once + +func DoltDevBuildPath() string { + cachedDoltDevBuildPathOnce.Do(func() { + tmp, err := os.MkdirTemp("", "binlog-replication-doltbin-") + if err != nil { + panic(err) + } + fullpath := filepath.Join(tmp, "dolt") + + originalWorkingDir, err := os.Getwd() + if err != nil { + panic(err) + } + + goDirPath := filepath.Join(originalWorkingDir, "..", "..", "..", "..") + + cmd := exec.Command("go", "build", "-o", fullpath, "./cmd/dolt") + cmd.Dir = goDirPath + output, err := cmd.CombinedOutput() + if err != nil { + panic("unable to build dolt for binlog integration tests: " + err.Error() + "\nFull output: " + string(output) + "\n") + } + cachedDoltDevBuildPath = fullpath + }) + return cachedDoltDevBuildPath +} + func teardown(t *testing.T) { if mySqlProcess != nil { stopMySqlServer(t) @@ -72,17 +114,17 @@ func teardown(t *testing.T) { // Output server logs on failure for easier debugging if t.Failed() { if oldDoltLogFilePath != "" { - fmt.Printf("\nDolt server log from %s:\n", oldDoltLogFilePath) - printFile(oldDoltLogFilePath) + t.Logf("\nDolt server log from %s:\n", oldDoltLogFilePath) + printFile(t, oldDoltLogFilePath) } - fmt.Printf("\nDolt server log from %s:\n", doltLogFilePath) - printFile(doltLogFilePath) - fmt.Printf("\nMySQL server log from %s:\n", mysqlLogFilePath) - printFile(mysqlLogFilePath) + t.Logf("\nDolt server log from %s:\n", doltLogFilePath) + printFile(t, doltLogFilePath) + t.Logf("\nMySQL server log from %s:\n", mysqlLogFilePath) + printFile(t, mysqlLogFilePath) mysqlErrorLogFilePath := filepath.Join(filepath.Dir(mysqlLogFilePath), "error_log.err") - fmt.Printf("\nMySQL server error log from %s:\n", mysqlErrorLogFilePath) - printFile(mysqlErrorLogFilePath) + t.Logf("\nMySQL server error log from %s:\n", mysqlErrorLogFilePath) + printFile(t, mysqlErrorLogFilePath) } else { // clean up temp files on clean test runs defer os.RemoveAll(testDir) @@ -194,7 +236,7 @@ func TestAutoRestartReplica(t *testing.T) { // Restart the Dolt replica stopDoltSqlServer(t) var err error - doltPort, doltProcess, err = startDoltSqlServer(testDir, nil) + doltPort, doltProcess, err = startDoltSqlServer(t, testDir, nil) require.NoError(t, err) // Assert that some test data replicates correctly @@ -218,7 +260,7 @@ func TestAutoRestartReplica(t *testing.T) { // Restart the Dolt replica stopDoltSqlServer(t) - doltPort, doltProcess, err = startDoltSqlServer(testDir, nil) + doltPort, doltProcess, err = startDoltSqlServer(t, testDir, nil) require.NoError(t, err) // SHOW REPLICA STATUS should show that replication is NOT running, with no errors @@ -590,11 +632,13 @@ func TestCharsetsAndCollations(t *testing.T) { // Test Helper Functions // -// waitForReplicaToCatchUp waits (up to 30s) for the replica to catch up with the primary database. The -// lag is measured by checking that gtid_executed is the same on the primary and replica. +// waitForReplicaToCatchUp waits for the replica to catch up with the primary database. The +// lag is measured by checking that gtid_executed is the same on the primary and replica. If +// no progress is made in 30 seconds, this function will fail the test. 
func waitForReplicaToCatchUp(t *testing.T) { timeLimit := 30 * time.Second + lastReplicaGtid := "" endTime := time.Now().Add(timeLimit) for time.Now().Before(endTime) { replicaGtid := queryGtid(t, replicaDatabase) @@ -602,8 +646,11 @@ func waitForReplicaToCatchUp(t *testing.T) { if primaryGtid == replicaGtid { return + } else if lastReplicaGtid != replicaGtid { + lastReplicaGtid = replicaGtid + endTime = time.Now().Add(timeLimit) } else { - fmt.Printf("primary and replica not in sync yet... (primary: %s, replica: %s)\n", primaryGtid, replicaGtid) + t.Logf("primary and replica not in sync yet... (primary: %s, replica: %s)\n", primaryGtid, replicaGtid) time.Sleep(250 * time.Millisecond) } } @@ -639,7 +686,7 @@ func waitForReplicaToReachGtid(t *testing.T, target int) { } } - fmt.Printf("replica has not reached transaction %d yet; currently at: %s \n", target, replicaGtid) + t.Logf("replica has not reached transaction %d yet; currently at: %s \n", target, replicaGtid) } t.Fatal("replica did not reach target GTID within " + timeLimit.String()) @@ -725,20 +772,13 @@ func startSqlServersWithDoltSystemVars(t *testing.T, doltPersistentSystemVars ma testDir = filepath.Join(os.TempDir(), fmt.Sprintf("%s-%v", t.Name(), time.Now().Unix())) err := os.MkdirAll(testDir, 0777) - - cmd := exec.Command("chmod", "777", testDir) - _, err = cmd.Output() - if err != nil { - panic(err) - } - require.NoError(t, err) - fmt.Printf("temp dir: %v \n", testDir) + t.Logf("temp dir: %v \n", testDir) // Start up primary and replica databases - mySqlPort, mySqlProcess, err = startMySqlServer(testDir) + mySqlPort, mySqlProcess, err = startMySqlServer(t, testDir) require.NoError(t, err) - doltPort, doltProcess, err = startDoltSqlServer(testDir, doltPersistentSystemVars) + doltPort, doltProcess, err = startDoltSqlServer(t, testDir, doltPersistentSystemVars) require.NoError(t, err) } @@ -856,25 +896,9 @@ func findFreePort() int { // startMySqlServer configures a starts a fresh MySQL server instance and returns the port it is running on, // and the os.Process handle. If unable to start up the MySQL server, an error is returned. -func startMySqlServer(dir string) (int, *os.Process, error) { - originalCwd, err := os.Getwd() - if err != nil { - panic(err) - } - - dir = dir + string(os.PathSeparator) + "mysql" + string(os.PathSeparator) - dataDir := dir + "mysql_data" - err = os.MkdirAll(dir, 0777) - if err != nil { - return -1, nil, err - } - cmd := exec.Command("chmod", "777", dir) - output, err := cmd.Output() - if err != nil { - panic(err) - } - - err = os.Chdir(dir) +func startMySqlServer(t *testing.T, dir string) (int, *os.Process, error) { + dir = filepath.Join(dir, "mysql") + err := os.MkdirAll(dir, 0777) if err != nil { return -1, nil, err } @@ -889,28 +913,31 @@ func startMySqlServer(dir string) (int, *os.Process, error) { } username := user.Username if username == "root" { - fmt.Printf("overriding current user (root) to run mysql as 'mysql' user instead\n") + t.Logf("overriding current user (root) to run mysql as 'mysql' user instead\n") username = "mysql" } + dataDir := filepath.Join(dir, "mysql_data") + // Check to see if the MySQL data directory has the "mysql" directory in it, which // tells us whether this MySQL instance has been initialized yet or not. 
initialized := directoryExists(filepath.Join(dataDir, "mysql")) if !initialized { // Create a fresh MySQL server for the primary - chmodCmd := exec.Command("mysqld", + initCmd := exec.Command("mysqld", "--no-defaults", "--user="+username, "--initialize-insecure", "--datadir="+dataDir, "--default-authentication-plugin=mysql_native_password") - output, err = chmodCmd.CombinedOutput() + initCmd.Dir = dir + output, err := initCmd.CombinedOutput() if err != nil { - return -1, nil, fmt.Errorf("unable to execute command %v: %v – %v", cmd.String(), err.Error(), string(output)) + return -1, nil, fmt.Errorf("unable to execute command %v: %v – %v", initCmd.String(), err.Error(), string(output)) } } - cmd = exec.Command("mysqld", + cmd := exec.Command("mysqld", "--no-defaults", "--user="+username, "--datadir="+dataDir, @@ -920,17 +947,18 @@ func startMySqlServer(dir string) (int, *os.Process, error) { fmt.Sprintf("--port=%v", mySqlPort), "--server-id=11223344", fmt.Sprintf("--socket=mysql-%v.sock", mySqlPort), - "--general_log_file="+dir+"general_log", - "--slow_query_log_file="+dir+"slow_query_log", + "--general_log_file="+filepath.Join(dir, "general_log"), + "--slow_query_log_file="+filepath.Join(dir, "slow_query_log"), "--log-error="+dir+"error_log", - fmt.Sprintf("--pid-file="+dir+"pid-%v.pid", mySqlPort)) + fmt.Sprintf("--pid-file="+filepath.Join(dir, "pid-%v.pid"), mySqlPort)) + cmd.Dir = dir mysqlLogFilePath = filepath.Join(dir, fmt.Sprintf("mysql-%d.out.log", time.Now().Unix())) mysqlLogFile, err = os.Create(mysqlLogFilePath) if err != nil { return -1, nil, err } - fmt.Printf("MySQL server logs at: %s \n", mysqlLogFilePath) + t.Logf("MySQL server logs at: %s \n", mysqlLogFilePath) cmd.Stdout = mysqlLogFile cmd.Stderr = mysqlLogFile err = cmd.Start() @@ -941,7 +969,7 @@ func startMySqlServer(dir string) (int, *os.Process, error) { dsn := fmt.Sprintf("root@tcp(127.0.0.1:%v)/", mySqlPort) primaryDatabase = sqlx.MustOpen("mysql", dsn) - err = waitForSqlServerToStart(primaryDatabase) + err = waitForSqlServerToStart(t, primaryDatabase) if err != nil { return -1, nil, err } @@ -955,8 +983,7 @@ func startMySqlServer(dir string) (int, *os.Process, error) { dsn = fmt.Sprintf("root@tcp(127.0.0.1:%v)/", mySqlPort) primaryDatabase = sqlx.MustOpen("mysql", dsn) - os.Chdir(originalCwd) - fmt.Printf("MySQL server started on port %v \n", mySqlPort) + t.Logf("MySQL server started on port %v \n", mySqlPort) return mySqlPort, cmd.Process, nil } @@ -971,43 +998,10 @@ func directoryExists(path string) bool { return info.IsDir() } -var cachedDoltDevBuildPath = "" - -func initializeDevDoltBuild(dir string, goDirPath string) string { - if cachedDoltDevBuildPath != "" { - return cachedDoltDevBuildPath - } - - // If we're not in a CI environment, don't worry about building a dev build - if os.Getenv("CI") != "true" { - return "" - } - - basedir := filepath.Dir(filepath.Dir(dir)) - fullpath := filepath.Join(basedir, fmt.Sprintf("devDolt-%d", os.Getpid())) - - _, err := os.Stat(fullpath) - if err == nil { - return fullpath - } - - fmt.Printf("building dolt dev build at: %s \n", fullpath) - cmd := exec.Command("go", "build", "-o", fullpath, "./cmd/dolt") - cmd.Dir = goDirPath - - output, err := cmd.CombinedOutput() - if err != nil { - panic("unable to build dolt for binlog integration tests: " + err.Error() + "\nFull output: " + string(output) + "\n") - } - cachedDoltDevBuildPath = fullpath - - return cachedDoltDevBuildPath -} - // startDoltSqlServer starts a Dolt sql-server on a free port from the specified directory 
|dir|. If // |doltPeristentSystemVars| is populated, then those system variables will be set, persistently, for // the Dolt database, before the Dolt sql-server is started. -func startDoltSqlServer(dir string, doltPersistentSystemVars map[string]string) (int, *os.Process, error) { +func startDoltSqlServer(t *testing.T, dir string, doltPersistentSystemVars map[string]string) (int, *os.Process, error) { dir = filepath.Join(dir, "dolt") err := os.MkdirAll(dir, 0777) if err != nil { @@ -1019,57 +1013,34 @@ func startDoltSqlServer(dir string, doltPersistentSystemVars map[string]string) if doltPort < 1 { doltPort = findFreePort() } - fmt.Printf("Starting Dolt sql-server on port: %d, with data dir %s\n", doltPort, dir) - - // take the CWD and move up four directories to find the go directory - if originalWorkingDir == "" { - var err error - originalWorkingDir, err = os.Getwd() - if err != nil { - panic(err) - } - } - goDirPath := filepath.Join(originalWorkingDir, "..", "..", "..", "..") - err = os.Chdir(goDirPath) - if err != nil { - panic(err) - } - - socketPath := filepath.Join("/tmp", fmt.Sprintf("dolt.%v.sock", doltPort)) + t.Logf("Starting Dolt sql-server on port: %d, with data dir %s\n", doltPort, dir) // use an admin user NOT named "root" to test that we don't require the "root" account adminUser := "admin" if doltPersistentSystemVars != nil && len(doltPersistentSystemVars) > 0 { // Initialize the dolt directory first - err = runDoltCommand(dir, goDirPath, "init", "--name=binlog-test", "--email=binlog@test") + err = runDoltCommand(t, dir, "init", "--name=binlog-test", "--email=binlog@test") if err != nil { return -1, nil, err } for systemVar, value := range doltPersistentSystemVars { query := fmt.Sprintf("SET @@PERSIST.%s=%s;", systemVar, value) - err = runDoltCommand(dir, goDirPath, "sql", fmt.Sprintf("-q=%s", query)) + err = runDoltCommand(t, dir, "sql", fmt.Sprintf("-q=%s", query)) if err != nil { return -1, nil, err } } } - args := []string{"go", "run", "./cmd/dolt", + args := []string{DoltDevBuildPath(), "sql-server", fmt.Sprintf("-u%s", adminUser), "--loglevel=TRACE", fmt.Sprintf("--data-dir=%s", dir), - fmt.Sprintf("--port=%v", doltPort), - fmt.Sprintf("--socket=%s", socketPath)} - - // If we're running in CI, use a precompiled dolt binary instead of go run - devDoltPath := initializeDevDoltBuild(dir, goDirPath) - if devDoltPath != "" { - args[2] = devDoltPath - args = args[2:] - } + fmt.Sprintf("--port=%v", doltPort)} + cmd := exec.Command(args[0], args[1:]...) 
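waitForReplicaToCatchUp, reworked earlier in this file, now fails only when the replica's GTID makes no progress for the full 30-second window, re-arming its deadline each time the observed GTID changes. A generic sketch of that wait-with-progress loop (names hypothetical):

package example

import "time"

// waitWithProgress polls observe() until it reports done, re-arming the deadline
// whenever the observed state changes; it gives up only after `window` passes
// with no change. This is the shape of the reworked waitForReplicaToCatchUp.
func waitWithProgress(observe func() (state string, done bool), window time.Duration) bool {
	last := ""
	deadline := time.Now().Add(window)
	for time.Now().Before(deadline) {
		state, done := observe()
		if done {
			return true
		}
		if state != last {
			last, deadline = state, time.Now().Add(window)
		}
		time.Sleep(250 * time.Millisecond)
	}
	return false
}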
// Set a unique process group ID so that we can cleanly kill this process, as well as @@ -1094,7 +1065,7 @@ func startDoltSqlServer(dir string, doltPersistentSystemVars map[string]string) if err != nil { return -1, nil, err } - fmt.Printf("dolt sql-server logs at: %s \n", doltLogFilePath) + t.Logf("dolt sql-server logs at: %s \n", doltLogFilePath) cmd.Stdout = doltLogFile cmd.Stderr = doltLogFile err = cmd.Start() @@ -1102,18 +1073,18 @@ func startDoltSqlServer(dir string, doltPersistentSystemVars map[string]string) return -1, nil, fmt.Errorf("unable to execute command %v: %v", cmd.String(), err.Error()) } - fmt.Printf("Dolt CMD: %s\n", cmd.String()) + t.Logf("Dolt CMD: %s\n", cmd.String()) dsn := fmt.Sprintf("%s@tcp(127.0.0.1:%v)/", adminUser, doltPort) replicaDatabase = sqlx.MustOpen("mysql", dsn) - err = waitForSqlServerToStart(replicaDatabase) + err = waitForSqlServerToStart(t, replicaDatabase) if err != nil { return -1, nil, err } mustCreateReplicatorUser(replicaDatabase) - fmt.Printf("Dolt server started on port %v \n", doltPort) + t.Logf("Dolt server started on port %v \n", doltPort) return doltPort, cmd.Process, nil } @@ -1125,24 +1096,17 @@ func mustCreateReplicatorUser(db *sqlx.DB) { } // runDoltCommand runs a short-lived dolt CLI command with the specified arguments from |doltArgs|. The Dolt data -// directory is specified from |doltDataDir| and |goDirPath| is the path to the go directory within the Dolt repo. +// directory is specified from |doltDataDir|. // This function will only return when the Dolt CLI command has completed, so it is not suitable for running // long-lived commands such as "sql-server". If the command fails, an error is returned with the combined output. -func runDoltCommand(doltDataDir string, goDirPath string, doltArgs ...string) error { - // If we're running in CI, use a precompiled dolt binary instead of go run - devDoltPath := initializeDevDoltBuild(doltDataDir, goDirPath) - - args := append([]string{"go", "run", "./cmd/dolt", +func runDoltCommand(t *testing.T, doltDataDir string, doltArgs ...string) error { + args := append([]string{DoltDevBuildPath(), fmt.Sprintf("--data-dir=%s", doltDataDir)}, doltArgs...) - if devDoltPath != "" { - args[2] = devDoltPath - args = args[2:] - } cmd := exec.Command(args[0], args[1:]...) - fmt.Printf("Running Dolt CMD: %s\n", cmd.String()) + t.Logf("Running Dolt CMD: %s\n", cmd.String()) output, err := cmd.CombinedOutput() - fmt.Printf("Dolt CMD output: %s\n", string(output)) + t.Logf("Dolt CMD output: %s\n", string(output)) if err != nil { return fmt.Errorf("%w: %s", err, string(output)) } @@ -1152,13 +1116,13 @@ func runDoltCommand(doltDataDir string, goDirPath string, doltArgs ...string) er // waitForSqlServerToStart polls the specified database to wait for it to become available, pausing // between retry attempts, and returning an error if it is not able to verify that the database is // available. -func waitForSqlServerToStart(database *sqlx.DB) error { - fmt.Printf("Waiting for server to start...\n") +func waitForSqlServerToStart(t *testing.T, database *sqlx.DB) error { + t.Logf("Waiting for server to start...\n") for counter := 0; counter < 30; counter++ { if database.Ping() == nil { return nil } - fmt.Printf("not up yet; waiting...\n") + t.Logf("not up yet; waiting...\n") time.Sleep(500 * time.Millisecond) } @@ -1166,10 +1130,10 @@ func waitForSqlServerToStart(database *sqlx.DB) error { } // printFile opens the specified filepath |path| and outputs the contents of that file to stdout. 
-func printFile(path string) { +func printFile(t *testing.T, path string) { file, err := os.Open(path) if err != nil { - fmt.Printf("Unable to open file: %s \n", err) + t.Logf("Unable to open file: %s \n", err) return } defer file.Close() @@ -1184,9 +1148,9 @@ func printFile(path string) { panic(err) } } - fmt.Print(s) + t.Log(s) } - fmt.Println() + t.Log() } // assertRepoStateFileExists asserts that the repo_state.json file is present for the specified diff --git a/go/libraries/doltcore/sqle/cluster/controller.go b/go/libraries/doltcore/sqle/cluster/controller.go index 3845a1e9fa8..4be3f36b341 100644 --- a/go/libraries/doltcore/sqle/cluster/controller.go +++ b/go/libraries/doltcore/sqle/cluster/controller.go @@ -688,9 +688,14 @@ func (c *Controller) RemoteSrvServerArgs(ctxFactory func(context.Context) (*sql. listenaddr := c.RemoteSrvListenAddr() args.HttpListenAddr = listenaddr args.GrpcListenAddr = listenaddr - args.Options = c.ServerOptions() + ctxInterceptor := sqle.SqlContextServerInterceptor{ + Factory: ctxFactory, + } + args.Options = append(args.Options, ctxInterceptor.Options()...) + args.Options = append(args.Options, c.ServerOptions()...) + args.HttpInterceptor = ctxInterceptor.HTTP(args.HttpInterceptor) var err error - args.FS, args.DBCache, err = sqle.RemoteSrvFSAndDBCache(ctxFactory, sqle.CreateUnknownDatabases) + args.DBCache, err = sqle.RemoteSrvDBCache(sqle.GetInterceptorSqlContext, sqle.CreateUnknownDatabases) if err != nil { return remotesrv.ServerArgs{}, err } @@ -699,7 +704,7 @@ func (c *Controller) RemoteSrvServerArgs(ctxFactory func(context.Context) (*sql. keyID := creds.PubKeyToKID(c.pub) keyIDStr := creds.B32CredsEncoding.EncodeToString(keyID) - args.HttpInterceptor = JWKSHandlerInterceptor(keyIDStr, c.pub) + args.HttpInterceptor = JWKSHandlerInterceptor(args.HttpInterceptor, keyIDStr, c.pub) return args, nil } diff --git a/go/libraries/doltcore/sqle/cluster/jwks.go b/go/libraries/doltcore/sqle/cluster/jwks.go index 1e3c357c1c3..36511ae62be 100644 --- a/go/libraries/doltcore/sqle/cluster/jwks.go +++ b/go/libraries/doltcore/sqle/cluster/jwks.go @@ -46,16 +46,21 @@ func (h JWKSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Write(b) } -func JWKSHandlerInterceptor(keyID string, pub ed25519.PublicKey) func(http.Handler) http.Handler { +func JWKSHandlerInterceptor(existing func(http.Handler) http.Handler, keyID string, pub ed25519.PublicKey) func(http.Handler) http.Handler { jh := JWKSHandler{KeyID: keyID, PublicKey: pub} return func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + this := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.EscapedPath() == "/.well-known/jwks.json" { jh.ServeHTTP(w, r) return } h.ServeHTTP(w, r) }) + if existing != nil { + return existing(this) + } else { + return this + } } } diff --git a/go/libraries/doltcore/sqle/database_provider.go b/go/libraries/doltcore/sqle/database_provider.go index 2b6b84b96e8..ec6f2caa13f 100644 --- a/go/libraries/doltcore/sqle/database_provider.go +++ b/go/libraries/doltcore/sqle/database_provider.go @@ -684,7 +684,8 @@ func (p *DoltDatabaseProvider) CloneDatabaseFromRemote( if exists { deleteErr := p.fs.Delete(dbName, true) if deleteErr != nil { - err = fmt.Errorf("%s: unable to clean up failed clone in directory '%s'", err.Error(), dbName) + err = fmt.Errorf("%s: unable to clean up failed clone in directory '%s': %s", + err.Error(), dbName, deleteErr.Error()) } } return err diff --git 
a/go/libraries/doltcore/sqle/dtables/branches_table.go b/go/libraries/doltcore/sqle/dtables/branches_table.go index a343d825d0b..1a32fd0e764 100644 --- a/go/libraries/doltcore/sqle/dtables/branches_table.go +++ b/go/libraries/doltcore/sqle/dtables/branches_table.go @@ -15,6 +15,7 @@ package dtables import ( + "errors" "fmt" "io" @@ -26,6 +27,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/schema" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/index" + "github.com/dolthub/dolt/go/store/hash" ) const branchesDefaultRowCount = 10 @@ -90,6 +92,7 @@ func (bt *BranchesTable) Schema() sql.Schema { if !bt.remote { columns = append(columns, &sql.Column{Name: "remote", Type: types.Text, Source: bt.tableName, PrimaryKey: false, Nullable: true}) columns = append(columns, &sql.Column{Name: "branch", Type: types.Text, Source: bt.tableName, PrimaryKey: false, Nullable: true}) + columns = append(columns, &sql.Column{Name: "dirty", Type: types.Boolean, Source: bt.tableName, PrimaryKey: false, Nullable: true}) } return columns } @@ -114,6 +117,7 @@ type BranchItr struct { table *BranchesTable branches []string commits []*doltdb.Commit + dirty []bool idx int } @@ -145,19 +149,28 @@ func NewBranchItr(ctx *sql.Context, table *BranchesTable) (*BranchItr, error) { branchNames := make([]string, len(branchRefs)) commits := make([]*doltdb.Commit, len(branchRefs)) + dirtyBits := make([]bool, len(branchRefs)) for i, branch := range branchRefs { commit, err := ddb.ResolveCommitRefAtRoot(ctx, branch, txRoot) - if err != nil { return nil, err } + var dirty bool + if !remote { + dirty, err = isDirty(ctx, ddb, commit, branch, txRoot) + if err != nil { + return nil, err + } + } + if branch.GetType() == ref.RemoteRefType { branchNames[i] = "remotes/" + branch.GetPath() } else { branchNames[i] = branch.GetPath() } + dirtyBits[i] = dirty commits[i] = commit } @@ -165,6 +178,7 @@ func NewBranchItr(ctx *sql.Context, table *BranchesTable) (*BranchItr, error) { table: table, branches: branchNames, commits: commits, + dirty: dirtyBits, idx: 0, }, nil } @@ -182,6 +196,7 @@ func (itr *BranchItr) Next(ctx *sql.Context) (sql.Row, error) { name := itr.branches[itr.idx] cm := itr.commits[itr.idx] + dirty := itr.dirty[itr.idx] meta, err := cm.GetCommitMeta(ctx) if err != nil { @@ -211,8 +226,53 @@ func (itr *BranchItr) Next(ctx *sql.Context) (sql.Row, error) { remoteName = branch.Remote branchName = branch.Merge.Ref.GetPath() } - return sql.NewRow(name, h.String(), meta.Name, meta.Email, meta.Time(), meta.Description, remoteName, branchName), nil + return sql.NewRow(name, h.String(), meta.Name, meta.Email, meta.Time(), meta.Description, remoteName, branchName, dirty), nil + } +} + +// isDirty returns true if the working ref points to a dirty branch. +func isDirty(ctx *sql.Context, ddb *doltdb.DoltDB, commit *doltdb.Commit, branch ref.DoltRef, txRoot hash.Hash) (bool, error) { + wsRef, err := ref.WorkingSetRefForHead(branch) + if err != nil { + return false, err + } + ws, err := ddb.ResolveWorkingSetAtRoot(ctx, wsRef, txRoot) + if err != nil { + if errors.Is(err, doltdb.ErrWorkingSetNotFound) { + // If there is no working set for this branch, then it is never dirty. This happens on servers commonly. 
+ return false, nil + } + return false, err + } + + workingRoot := ws.WorkingRoot() + workingRootHash, err := workingRoot.HashOf() + if err != nil { + return false, err + } + stagedRoot := ws.StagedRoot() + stagedRootHash, err := stagedRoot.HashOf() + if err != nil { + return false, err + } + + dirty := false + if workingRootHash != stagedRootHash { + dirty = true + } else { + cmRt, err := commit.GetRootValue(ctx) + if err != nil { + return false, err + } + cmRtHash, err := cmRt.HashOf() + if err != nil { + return false, err + } + if cmRtHash != workingRootHash { + dirty = true + } } + return dirty, nil } // Close closes the iterator. diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_harness.go b/go/libraries/doltcore/sqle/enginetest/dolt_harness.go index be5b92cb257..43436848afd 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_harness.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_harness.go @@ -515,7 +515,7 @@ func (d *DoltHarness) newProvider() sql.MutableDatabaseProvider { var dEnv *env.DoltEnv if d.useLocalFilesystem { - dEnv = dtestutils.CreateTestEnvForLocalFilesystem() + dEnv = dtestutils.CreateTestEnvForLocalFilesystem(d.t.TempDir()) } else { dEnv = dtestutils.CreateTestEnv() } diff --git a/go/libraries/doltcore/sqle/read_replica_database.go b/go/libraries/doltcore/sqle/read_replica_database.go index fba8ca14cd4..4c757684ea9 100644 --- a/go/libraries/doltcore/sqle/read_replica_database.go +++ b/go/libraries/doltcore/sqle/read_replica_database.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "sort" "strings" "sync" @@ -509,9 +510,20 @@ func getReplicationRefs(ctx *sql.Context, rrd ReadReplicaDatabase) ( func refsToDelete(remRefs, localRefs []doltdb.RefWithHash) []doltdb.RefWithHash { toDelete := make([]doltdb.RefWithHash, 0, len(localRefs)) var i, j int + + // Before we map remote refs to local refs to determine which refs to delete, we need to sort them + // by Ref.String() – this ensures a unique identifier that does not conflict with other refs, unlike + // Ref.GetPath(), which can conflict if a branch or tag has the same name. + sort.Slice(remRefs, func(i, j int) bool { + return remRefs[i].Ref.String() < remRefs[j].Ref.String() + }) + sort.Slice(localRefs, func(i, j int) bool { + return localRefs[i].Ref.String() < localRefs[j].Ref.String() + }) + for i < len(remRefs) && j < len(localRefs) { - rem := remRefs[i].Ref.GetPath() - local := localRefs[j].Ref.GetPath() + rem := remRefs[i].Ref.String() + local := localRefs[j].Ref.String() if rem == local { i++ j++ diff --git a/go/libraries/doltcore/sqle/remotesrv.go b/go/libraries/doltcore/sqle/remotesrv.go index 87e6164764e..3d63cb68e74 100644 --- a/go/libraries/doltcore/sqle/remotesrv.go +++ b/go/libraries/doltcore/sqle/remotesrv.go @@ -16,13 +16,15 @@ package sqle import ( "context" + "errors" + "net/http" "github.com/dolthub/go-mysql-server/sql" + "google.golang.org/grpc" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/remotesrv" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" - "github.com/dolthub/dolt/go/libraries/utils/filesys" "github.com/dolthub/dolt/go/store/datas" ) @@ -81,17 +83,12 @@ type CreateUnknownDatabasesSetting bool const CreateUnknownDatabases CreateUnknownDatabasesSetting = true const DoNotCreateUnknownDatabases CreateUnknownDatabasesSetting = false -// Considers |args| and returns a new |remotesrv.ServerArgs| instance which -// will serve databases accessible through |ctxFactory|. 
-func RemoteSrvFSAndDBCache(ctxFactory func(context.Context) (*sql.Context, error), createSetting CreateUnknownDatabasesSetting) (filesys.Filesys, remotesrv.DBCache, error) { - sqlCtx, err := ctxFactory(context.Background()) - if err != nil { - return nil, nil, err - } - sess := dsess.DSessFromSess(sqlCtx.Session) - fs := sess.Provider().FileSystem() +// Returns a remotesrv.DBCache instance which will use the *sql.Context +// returned from |ctxFactory| to access a database in the session +// DatabaseProvider. +func RemoteSrvDBCache(ctxFactory func(context.Context) (*sql.Context, error), createSetting CreateUnknownDatabasesSetting) (remotesrv.DBCache, error) { dbcache := remotesrvStore{ctxFactory, bool(createSetting)} - return fs, dbcache, nil + return dbcache, nil } func WithUserPasswordAuth(args remotesrv.ServerArgs, authnz remotesrv.AccessControl) remotesrv.ServerArgs { @@ -102,3 +99,88 @@ func WithUserPasswordAuth(args remotesrv.ServerArgs, authnz remotesrv.AccessCont args.Options = append(args.Options, si.Options()...) return args } + +type SqlContextServerInterceptor struct { + Factory func(context.Context) (*sql.Context, error) +} + +type serverStreamWrapper struct { + grpc.ServerStream + ctx context.Context +} + +func (s serverStreamWrapper) Context() context.Context { + return s.ctx +} + +type sqlContextInterceptorKey struct{} + +func GetInterceptorSqlContext(ctx context.Context) (*sql.Context, error) { + if v := ctx.Value(sqlContextInterceptorKey{}); v != nil { + return v.(*sql.Context), nil + } + return nil, errors.New("misconfiguration; a sql.Context should always be available from the interceptor chain.") +} + +func (si SqlContextServerInterceptor) Stream() grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + sqlCtx, err := si.Factory(ss.Context()) + if err != nil { + return err + } + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) + defer sql.SessionEnd(sqlCtx.Session) + newCtx := context.WithValue(ss.Context(), sqlContextInterceptorKey{}, sqlCtx) + newSs := serverStreamWrapper{ + ServerStream: ss, + ctx: newCtx, + } + return handler(srv, newSs) + } +} + +func (si SqlContextServerInterceptor) Unary() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + sqlCtx, err := si.Factory(ctx) + if err != nil { + return nil, err + } + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) + defer sql.SessionEnd(sqlCtx.Session) + newCtx := context.WithValue(ctx, sqlContextInterceptorKey{}, sqlCtx) + return handler(newCtx, req) + } +} + +func (si SqlContextServerInterceptor) HTTP(existing func(http.Handler) http.Handler) func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + this := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + sqlCtx, err := si.Factory(ctx) + if err != nil { + http.Error(w, "could not initialize sql.Context", http.StatusInternalServerError) + return + } + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) + defer sql.SessionEnd(sqlCtx.Session) + newCtx := context.WithValue(ctx, sqlContextInterceptorKey{}, sqlCtx) + newReq := r.WithContext(newCtx) + h.ServeHTTP(w, newReq) + }) + if existing != nil { + return existing(this) + } else { + return this + } + } +} + +func (si SqlContextServerInterceptor) 
Options() []grpc.ServerOption { + return []grpc.ServerOption{ + grpc.ChainUnaryInterceptor(si.Unary()), + grpc.ChainStreamInterceptor(si.Stream()), + } +} diff --git a/go/libraries/doltcore/sqle/replication_test.go b/go/libraries/doltcore/sqle/replication_test.go index f2922de2147..538db4c3aea 100644 --- a/go/libraries/doltcore/sqle/replication_test.go +++ b/go/libraries/doltcore/sqle/replication_test.go @@ -80,12 +80,17 @@ func TestReplicationBranches(t *testing.T) { local: []string{"feature4", "feature5", "feature6", "feature7", "feature8", "feature9"}, expToDelete: []string{"feature4", "feature5", "feature6", "feature7", "feature8", "feature9"}, }, + { + remote: []string{"main", "new1", "a1"}, + local: []string{"main", "a1"}, + expToDelete: []string{}, + }, } for _, tt := range tests { remoteRefs := make([]doltdb.RefWithHash, len(tt.remote)) for i := range tt.remote { - remoteRefs[i] = doltdb.RefWithHash{Ref: ref.NewRemoteRef("", tt.remote[i])} + remoteRefs[i] = doltdb.RefWithHash{Ref: ref.NewBranchRef(tt.remote[i])} } localRefs := make([]doltdb.RefWithHash, len(tt.local)) for i := range tt.local { @@ -96,6 +101,6 @@ func TestReplicationBranches(t *testing.T) { for i := range diff { diffNames[i] = diff[i].Ref.GetPath() } - assert.Equal(t, diffNames, tt.expToDelete) + assert.Equal(t, tt.expToDelete, diffNames) } } diff --git a/go/libraries/doltcore/sqle/sqlselect_test.go b/go/libraries/doltcore/sqle/sqlselect_test.go index 7f8fc08464d..d065a84f0bb 100644 --- a/go/libraries/doltcore/sqle/sqlselect_test.go +++ b/go/libraries/doltcore/sqle/sqlselect_test.go @@ -784,6 +784,7 @@ func BasicSelectTests() []SelectTest { "Initialize data repository", "", "", + true, // Test setup has a dirty workspace. }, }, ExpectedSqlSchema: sql.Schema{ @@ -795,6 +796,7 @@ func BasicSelectTests() []SelectTest { &sql.Column{Name: "latest_commit_message", Type: gmstypes.Text}, &sql.Column{Name: "remote", Type: gmstypes.Text}, &sql.Column{Name: "branch", Type: gmstypes.Text}, + &sql.Column{Name: "dirty", Type: gmstypes.Boolean}, }, }, } diff --git a/go/libraries/doltcore/sqle/testutil.go b/go/libraries/doltcore/sqle/testutil.go index ee4ad010714..5379fbc77f3 100644 --- a/go/libraries/doltcore/sqle/testutil.go +++ b/go/libraries/doltcore/sqle/testutil.go @@ -452,7 +452,7 @@ func CreateEmptyTestTable(dEnv *env.DoltEnv, tableName string, sch schema.Schema return dEnv.UpdateWorkingRoot(ctx, newRoot) } -// CreateTestDatabase creates a test database with the test data set in it. +// CreateTestDatabase creates a test database with the test data set in it. Has a dirty workspace as well. 
func CreateTestDatabase() (*env.DoltEnv, error) { ctx := context.Background() dEnv, err := CreateEmptyTestDatabase() diff --git a/go/libraries/events/event_flush_test.go b/go/libraries/events/event_flush_test.go index 1798aba4db6..70934a593de 100644 --- a/go/libraries/events/event_flush_test.go +++ b/go/libraries/events/event_flush_test.go @@ -105,7 +105,7 @@ func TestEventFlushing(t *testing.T) { fs := filesys.LocalFS path := filepath.Join(dPath, evtPath) - dDir := testLib.TestDir(path) + dDir := testLib.TestDir(t.TempDir(), path) ft = createFlushTester(fs, "", dDir) } diff --git a/go/libraries/utils/filesys/fs_test.go b/go/libraries/utils/filesys/fs_test.go index 0982dd79931..a5cfa7ed318 100644 --- a/go/libraries/utils/filesys/fs_test.go +++ b/go/libraries/utils/filesys/fs_test.go @@ -41,8 +41,8 @@ var filesysetmsToTest = map[string]Filesys{ } func TestFilesystems(t *testing.T) { - dir := test.TestDir("filesys_test") - newLocation := test.TestDir("newLocation") + dir := test.TestDir(t.TempDir(), "filesys_test") + newLocation := test.TestDir(t.TempDir(), "newLocation") subdir := filepath.Join(dir, "subdir") subdirFile := filepath.Join(subdir, testSubdirFilename) fp := filepath.Join(dir, testFilename) @@ -186,7 +186,7 @@ func TestNewInMemFS(t *testing.T) { } func TestRecursiveFSIteration(t *testing.T) { - dir := test.TestDir("TestRecursiveFSIteration") + dir := test.TestDir(t.TempDir(), "TestRecursiveFSIteration") for fsName, fs := range filesysetmsToTest { var expectedDirs []string @@ -215,7 +215,7 @@ func TestRecursiveFSIteration(t *testing.T) { } func TestFSIteration(t *testing.T) { - dir := test.TestDir("TestFSIteration") + dir := test.TestDir(t.TempDir(), "TestFSIteration") for fsName, fs := range filesysetmsToTest { var expectedDirs []string @@ -249,7 +249,7 @@ func TestFSIteration(t *testing.T) { } func TestDeletes(t *testing.T) { - dir := test.TestDir("TestDeletes") + dir := test.TestDir(t.TempDir(), "TestDeletes") for fsName, fs := range filesysetmsToTest { var ignored []string diff --git a/go/libraries/utils/test/files.go b/go/libraries/utils/test/files.go index d8561fa9e29..c75105ee5d8 100644 --- a/go/libraries/utils/test/files.go +++ b/go/libraries/utils/test/files.go @@ -22,27 +22,19 @@ import ( ) // TestDir creates a subdirectory inside the systems temp directory -func TestDir(testName string) string { - id, err := uuid.NewRandom() - - if err != nil { - panic(ShouldNeverHappen) - } - - return filepath.Join(os.TempDir(), testName, id.String()) +func TestDir(dir, testName string) string { + return filepath.Join(dir, testName, uuid.NewString()) } // ChangeToTestDir creates a new test directory and changes the current directory to be -func ChangeToTestDir(testName string) (string, error) { - dir := TestDir(testName) +func ChangeToTestDir(tempDir, testName string) (string, error) { + dir := TestDir(tempDir, testName) err := os.MkdirAll(dir, os.ModePerm) - if err != nil { return "", err } err = os.Chdir(dir) - if err != nil { return "", err } diff --git a/go/libraries/utils/test/test_test.go b/go/libraries/utils/test/test_test.go index a5c040346e9..a486f794f53 100644 --- a/go/libraries/utils/test/test_test.go +++ b/go/libraries/utils/test/test_test.go @@ -24,7 +24,7 @@ import ( // test your tests so you can test while you test func TestLDTestUtils(t *testing.T) { - dir, err := ChangeToTestDir("TestLDTestUtils") + dir, err := ChangeToTestDir(t.TempDir(), "TestLDTestUtils") if err != nil { t.Fatal("Couldn't change to test dir") diff --git 
a/go/performance/import_benchmarker/cmd/main.go b/go/performance/import_benchmarker/cmd/main.go index 71ad9eccf98..728ea3f9584 100644 --- a/go/performance/import_benchmarker/cmd/main.go +++ b/go/performance/import_benchmarker/cmd/main.go @@ -42,9 +42,16 @@ func main() { if err != nil { log.Fatalln(err) } + defer os.RemoveAll(tmpdir) + + userdir, err := os.MkdirTemp("", "import-benchmarker-") + if err != nil { + log.Fatalln(err) + } + defer os.RemoveAll(userdir) results := new(ib.ImportResults) - u, err := driver.NewDoltUser() + u, err := driver.NewDoltUser(userdir) for _, test := range def.Tests { test.Results = results test.InitWithTmpDir(tmpdir) @@ -73,5 +80,4 @@ func main() { } else { fmt.Println(results.SqlDump()) } - os.Exit(0) } diff --git a/go/performance/import_benchmarker/testdef.go b/go/performance/import_benchmarker/testdef.go index 3bc352cfc10..b0c42d500bd 100644 --- a/go/performance/import_benchmarker/testdef.go +++ b/go/performance/import_benchmarker/testdef.go @@ -214,7 +214,7 @@ func (test *ImportTest) Run(t *testing.T) { test.InitWithTmpDir(tmp) } - u, err := driver.NewDoltUser() + u, err := driver.NewDoltUser(t.TempDir()) for _, r := range test.Repos { if r.ExternalServer != nil { err := test.RunExternalServerTests(r.Name, r.ExternalServer) diff --git a/go/performance/sysbench/cmd/main.go b/go/performance/sysbench/cmd/main.go index 77f6769799e..caf493e6aa2 100644 --- a/go/performance/sysbench/cmd/main.go +++ b/go/performance/sysbench/cmd/main.go @@ -52,9 +52,16 @@ func main() { if err != nil { log.Fatalln(err) } + defer os.RemoveAll(tmpdir) + + userdir, err := os.MkdirTemp("", "sysbench-user-dir_") + if err != nil { + log.Fatalln(err) + } + defer os.RemoveAll(userdir) results := new(sysbench.Results) - u, err := driver.NewDoltUser() + u, err := driver.NewDoltUser(userdir) for _, test := range defs.Tests { test.InitWithTmpDir(tmpdir) @@ -83,5 +90,4 @@ func main() { } else { fmt.Println(results.SqlDump()) } - os.Exit(0) } diff --git a/go/performance/sysbench/testdef.go b/go/performance/sysbench/testdef.go index 5fc910d2524..7b3d3d71bdd 100644 --- a/go/performance/sysbench/testdef.go +++ b/go/performance/sysbench/testdef.go @@ -440,7 +440,7 @@ func (test *Script) Run(t *testing.T) { } results := new(Results) - u, err := driver.NewDoltUser() + u, err := driver.NewDoltUser(t.TempDir()) test.Results = results test.InitWithTmpDir(tmpdir) for _, r := range test.Repos { diff --git a/go/store/chunks/chunk_store.go b/go/store/chunks/chunk_store.go index 19bf4d07f96..bb5fe17a162 100644 --- a/go/store/chunks/chunk_store.go +++ b/go/store/chunks/chunk_store.go @@ -225,11 +225,11 @@ type ChunkStoreGarbageCollector interface { // // This function should not block indefinitely and should return an // error if a GC is already in progress. - BeginGC(addChunk func(hash.Hash) bool) error + BeginGC(addChunk func(hash.Hash) bool, mode GCMode) error // EndGC indicates that the GC is over. The previously provided // addChunk function must not be called after this function. - EndGC() + EndGC(mode GCMode) // MarkAndSweepChunks returns a handle that can be used to supply // hashes which should be saved into |dest|. The hashes are @@ -257,6 +257,12 @@ type GenerationalCS interface { NewGen() ChunkStoreGarbageCollector OldGen() ChunkStoreGarbageCollector GhostGen() ChunkStore + + // Has the same return values as OldGen().HasMany, but should be used by a + // generational GC process as the filter function instead of + // OldGen().HasMany. 
This function never takes read dependencies on the + // chunks that it queries. + OldGenGCFilter() HasManyFunc } var ErrUnsupportedOperation = errors.New("operation not supported") diff --git a/go/store/chunks/memory_store.go b/go/store/chunks/memory_store.go index a7fd1ae5725..ec664b03748 100644 --- a/go/store/chunks/memory_store.go +++ b/go/store/chunks/memory_store.go @@ -335,11 +335,11 @@ func (ms *MemoryStoreView) Commit(ctx context.Context, current, last hash.Hash) return success, nil } -func (ms *MemoryStoreView) BeginGC(keeper func(hash.Hash) bool) error { +func (ms *MemoryStoreView) BeginGC(keeper func(hash.Hash) bool, _ GCMode) error { return ms.transitionToGC(keeper) } -func (ms *MemoryStoreView) EndGC() { +func (ms *MemoryStoreView) EndGC(_ GCMode) { ms.transitionToNoGC() } diff --git a/go/store/chunks/test_utils.go b/go/store/chunks/test_utils.go index 36e7467bbb6..084166caaa1 100644 --- a/go/store/chunks/test_utils.go +++ b/go/store/chunks/test_utils.go @@ -75,20 +75,20 @@ func (s *TestStoreView) Put(ctx context.Context, c Chunk, getAddrs GetAddrsCurry return s.ChunkStore.Put(ctx, c, getAddrs) } -func (s *TestStoreView) BeginGC(keeper func(hash.Hash) bool) error { +func (s *TestStoreView) BeginGC(keeper func(hash.Hash) bool, mode GCMode) error { collector, ok := s.ChunkStore.(ChunkStoreGarbageCollector) if !ok { return ErrUnsupportedOperation } - return collector.BeginGC(keeper) + return collector.BeginGC(keeper, mode) } -func (s *TestStoreView) EndGC() { +func (s *TestStoreView) EndGC(mode GCMode) { collector, ok := s.ChunkStore.(ChunkStoreGarbageCollector) if !ok { panic(ErrUnsupportedOperation) } - collector.EndGC() + collector.EndGC(mode) } func (s *TestStoreView) MarkAndSweepChunks(ctx context.Context, getAddrs GetAddrsCurry, filter HasManyFunc, dest ChunkStore, mode GCMode) (MarkAndSweeper, error) { diff --git a/go/store/nbs/archive_build.go b/go/store/nbs/archive_build.go index 72f15d94f04..349bf239d11 100644 --- a/go/store/nbs/archive_build.go +++ b/go/store/nbs/archive_build.go @@ -425,7 +425,7 @@ func gatherAllChunks(ctx context.Context, cs chunkSource, idx tableIndex, stats return nil, nil, err } - bytes, err := cs.get(ctx, h, stats) + bytes, _, err := cs.get(ctx, h, nil, stats) if err != nil { return nil, nil, err } @@ -907,7 +907,7 @@ func (csc *simpleChunkSourceCache) get(ctx context.Context, h hash.Hash, stats * return chk, nil } - bytes, err := csc.cs.get(ctx, h, stats) + bytes, _, err := csc.cs.get(ctx, h, nil, stats) if bytes == nil || err != nil { return nil, err } @@ -919,7 +919,8 @@ func (csc *simpleChunkSourceCache) get(ctx context.Context, h hash.Hash, stats * // has returns true if the chunk is in the ChunkSource. This is not related to what is cached, just a helper. func (csc *simpleChunkSourceCache) has(h hash.Hash) (bool, error) { - return csc.cs.has(h) + res, _, err := csc.cs.has(h, nil) + return res, err } // addresses get all chunk addresses of the ChunkSource as a hash.HashSet. 
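The chunks-level GC interface above now threads both a keeper callback and a GCMode through BeginGC/EndGC, and GenerationalCS gains OldGenGCFilter so a full collection can filter against the old generation without taking read dependencies. Below is a minimal sketch of the call shape, using only signatures visible in this diff; the runGC helper, its parameters, and the no-op keeper are illustrative assumptions, not Dolt's actual GC driver.

package gcsketch

import (
	"context"

	"github.com/dolthub/dolt/go/store/chunks"
	"github.com/dolthub/dolt/go/store/hash"
)

// runGC brackets a hypothetical collection with the new mode-aware API.
func runGC(ctx context.Context, cs chunks.ChunkStoreGarbageCollector, gcs chunks.GenerationalCS,
	dest chunks.ChunkStore, getAddrs chunks.GetAddrsCurry, mode chunks.GCMode) error {
	// No-op keeper for illustration. In the chunk sources of this patch, a
	// keeper that returns true for an address makes the racing read or write
	// report gcBehavior_Block instead of completing.
	keeper := func(h hash.Hash) bool { return false }

	if err := cs.BeginGC(keeper, mode); err != nil {
		return err // e.g. a collection is already in progress
	}
	defer cs.EndGC(mode)

	// For a full collection the old generation is also rewritten, so the
	// dependency-free filter replaces OldGen().HasMany as the filter function.
	filter := gcs.OldGenGCFilter()

	sweeper, err := cs.MarkAndSweepChunks(ctx, getAddrs, filter, dest, mode)
	if err != nil {
		return err
	}
	_ = sweeper // a real driver would feed root hashes through the sweeper and close it
	return nil
}
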
diff --git a/go/store/nbs/archive_chunk_source.go b/go/store/nbs/archive_chunk_source.go index 3ccd0183c57..1b1f55d8265 100644 --- a/go/store/nbs/archive_chunk_source.go +++ b/go/store/nbs/archive_chunk_source.go @@ -64,42 +64,63 @@ func openReader(file string) (io.ReaderAt, uint64, error) { return f, uint64(stat.Size()), nil } -func (acs archiveChunkSource) has(h hash.Hash) (bool, error) { - return acs.aRdr.has(h), nil +func (acs archiveChunkSource) has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) { + res := acs.aRdr.has(h) + if res && keeper != nil && keeper(h) { + return false, gcBehavior_Block, nil + } + return res, gcBehavior_Continue, nil } -func (acs archiveChunkSource) hasMany(addrs []hasRecord) (bool, error) { +func (acs archiveChunkSource) hasMany(addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) { // single threaded first pass. foundAll := true for i, addr := range addrs { - if acs.aRdr.has(*(addr.a)) { + h := *addr.a + if acs.aRdr.has(h) { + if keeper != nil && keeper(h) { + return false, gcBehavior_Block, nil + } addrs[i].has = true } else { foundAll = false } } - return !foundAll, nil + return !foundAll, gcBehavior_Continue, nil } -func (acs archiveChunkSource) get(ctx context.Context, h hash.Hash, stats *Stats) ([]byte, error) { - // ctx, stats ? NM4. - return acs.aRdr.get(h) +func (acs archiveChunkSource) get(ctx context.Context, h hash.Hash, keeper keeperF, stats *Stats) ([]byte, gcBehavior, error) { + res, err := acs.aRdr.get(h) + if err != nil { + return nil, gcBehavior_Continue, err + } + if res != nil && keeper != nil && keeper(h) { + return nil, gcBehavior_Block, nil + } + return res, gcBehavior_Continue, nil } -func (acs archiveChunkSource) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (bool, error) { +func (acs archiveChunkSource) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { // single threaded first pass. foundAll := true for i, req := range reqs { - data, err := acs.aRdr.get(*req.a) - if err != nil || data == nil { + h := *req.a + data, err := acs.aRdr.get(h) + if err != nil { + return true, gcBehavior_Continue, err + } + if data == nil { foundAll = false } else { + if keeper != nil && keeper(h) { + return true, gcBehavior_Block, nil + } chunk := chunks.NewChunk(data) found(ctx, &chunk) reqs[i].found = true } } - return !foundAll, nil + return !foundAll, gcBehavior_Continue, nil } // iterate iterates over the archive chunks. The callback is called for each chunk in the archive. 
This is not optimized @@ -146,14 +167,14 @@ func (acs archiveChunkSource) clone() (chunkSource, error) { return archiveChunkSource{acs.file, rdr}, nil } -func (acs archiveChunkSource) getRecordRanges(_ context.Context, _ []getRecord) (map[hash.Hash]Range, error) { - return nil, errors.New("Archive chunk source does not support getRecordRanges") +func (acs archiveChunkSource) getRecordRanges(_ context.Context, _ []getRecord, _ keeperF) (map[hash.Hash]Range, gcBehavior, error) { + return nil, gcBehavior_Continue, errors.New("Archive chunk source does not support getRecordRanges") } -func (acs archiveChunkSource) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) (bool, error) { +func (acs archiveChunkSource) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { return acs.getMany(ctx, eg, reqs, func(ctx context.Context, chk *chunks.Chunk) { found(ctx, ChunkToCompressedChunk(*chk)) - }, stats) + }, keeper, stats) } func (acs archiveChunkSource) iterateAllChunks(ctx context.Context, cb func(chunks.Chunk)) error { diff --git a/go/store/nbs/archive_test.go b/go/store/nbs/archive_test.go index da3c324cf2a..c78d3b5710e 100644 --- a/go/store/nbs/archive_test.go +++ b/go/store/nbs/archive_test.go @@ -655,28 +655,28 @@ type testChunkSource struct { var _ chunkSource = (*testChunkSource)(nil) -func (tcs *testChunkSource) get(_ context.Context, h hash.Hash, _ *Stats) ([]byte, error) { +func (tcs *testChunkSource) get(_ context.Context, h hash.Hash, _ keeperF, _ *Stats) ([]byte, gcBehavior, error) { for _, chk := range tcs.chunks { if chk.Hash() == h { - return chk.Data(), nil + return chk.Data(), gcBehavior_Continue, nil } } - return nil, errors.New("not found") + return nil, gcBehavior_Continue, errors.New("not found") } -func (tcs *testChunkSource) has(h hash.Hash) (bool, error) { +func (tcs *testChunkSource) has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) { panic("never used") } -func (tcs *testChunkSource) hasMany(addrs []hasRecord) (bool, error) { +func (tcs *testChunkSource) hasMany(addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) { panic("never used") } -func (tcs *testChunkSource) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (bool, error) { +func (tcs *testChunkSource) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { panic("never used") } -func (tcs *testChunkSource) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) (bool, error) { +func (tcs *testChunkSource) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { panic("never used") } @@ -700,7 +700,7 @@ func (tcs *testChunkSource) reader(ctx context.Context) (io.ReadCloser, uint64, panic("never used") } -func (tcs *testChunkSource) getRecordRanges(ctx context.Context, requests []getRecord) (map[hash.Hash]Range, error) { +func (tcs *testChunkSource) getRecordRanges(ctx context.Context, requests []getRecord, keeper keeperF) (map[hash.Hash]Range, gcBehavior, error) { panic("never used") } diff --git 
a/go/store/nbs/aws_table_persister.go b/go/store/nbs/aws_table_persister.go index cc58ffea894..816a9314620 100644 --- a/go/store/nbs/aws_table_persister.go +++ b/go/store/nbs/aws_table_persister.go @@ -115,25 +115,31 @@ func (s3p awsTablePersister) key(k string) string { return k } -func (s3p awsTablePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, stats *Stats) (chunkSource, error) { - name, data, chunkCount, err := mt.write(haver, stats) - +func (s3p awsTablePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, keeper keeperF, stats *Stats) (chunkSource, gcBehavior, error) { + name, data, chunkCount, gcb, err := mt.write(haver, keeper, stats) if err != nil { - return emptyChunkSource{}, err + return emptyChunkSource{}, gcBehavior_Continue, err + } + if gcb != gcBehavior_Continue { + return emptyChunkSource{}, gcb, nil } if chunkCount == 0 { - return emptyChunkSource{}, nil + return emptyChunkSource{}, gcBehavior_Continue, nil } err = s3p.multipartUpload(ctx, bytes.NewReader(data), uint64(len(data)), name.String()) if err != nil { - return emptyChunkSource{}, err + return emptyChunkSource{}, gcBehavior_Continue, err } tra := &s3TableReaderAt{&s3ObjectReader{s3: s3p.s3, bucket: s3p.bucket, readRl: s3p.rl, ns: s3p.ns}, name} - return newReaderFromIndexData(ctx, s3p.q, data, name, tra, s3BlockSize) + src, err := newReaderFromIndexData(ctx, s3p.q, data, name, tra, s3BlockSize) + if err != nil { + return emptyChunkSource{}, gcBehavior_Continue, err + } + return src, gcBehavior_Continue, nil } func (s3p awsTablePersister) multipartUpload(ctx context.Context, r io.Reader, sz uint64, key string) error { diff --git a/go/store/nbs/aws_table_persister_test.go b/go/store/nbs/aws_table_persister_test.go index 4ab92c1651b..3187f2e08b6 100644 --- a/go/store/nbs/aws_table_persister_test.go +++ b/go/store/nbs/aws_table_persister_test.go @@ -90,7 +90,7 @@ func TestAWSTablePersisterPersist(t *testing.T) { s3svc := makeFakeS3(t) s3p := awsTablePersister{s3: s3svc, bucket: "bucket", limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}} - src, err := s3p.Persist(context.Background(), mt, nil, &Stats{}) + src, _, err := s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) defer src.close() @@ -108,7 +108,7 @@ func TestAWSTablePersisterPersist(t *testing.T) { s3svc := makeFakeS3(t) s3p := awsTablePersister{s3: s3svc, bucket: "bucket", limits: limits64mb, ns: ns, q: &UnlimitedQuotaProvider{}} - src, err := s3p.Persist(context.Background(), mt, nil, &Stats{}) + src, _, err := s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) defer src.close() if assert.True(mustUint32(src.count()) > 0) { @@ -133,7 +133,7 @@ func TestAWSTablePersisterPersist(t *testing.T) { s3svc := makeFakeS3(t) s3p := awsTablePersister{s3: s3svc, bucket: "bucket", limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}} - src, err := s3p.Persist(context.Background(), mt, existingTable, &Stats{}) + src, _, err := s3p.Persist(context.Background(), mt, existingTable, nil, &Stats{}) require.NoError(t, err) defer src.close() assert.True(mustUint32(src.count()) == 0) @@ -148,7 +148,7 @@ func TestAWSTablePersisterPersist(t *testing.T) { s3svc := &failingFakeS3{makeFakeS3(t), sync.Mutex{}, 1} s3p := awsTablePersister{s3: s3svc, bucket: "bucket", limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}} - _, err := s3p.Persist(context.Background(), mt, nil, &Stats{}) + _, _, err := s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) 
assert.Error(err) }) } @@ -306,7 +306,7 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) { for i := 0; i < len(chunks); i++ { mt := newMemTable(uint64(2 * targetPartSize)) mt.addChunk(computeAddr(chunks[i]), chunks[i]) - cs, err := s3p.Persist(context.Background(), mt, nil, &Stats{}) + cs, _, err := s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) sources = append(sources, cs) } @@ -379,7 +379,7 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) { } var err error - sources[i], err = s3p.Persist(context.Background(), mt, nil, &Stats{}) + sources[i], _, err = s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) } src, _, err := s3p.ConjoinAll(context.Background(), sources, &Stats{}) @@ -417,9 +417,9 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) { rand.Read(medChunks[i]) mt.addChunk(computeAddr(medChunks[i]), medChunks[i]) } - cs1, err := s3p.Persist(context.Background(), mt, nil, &Stats{}) + cs1, _, err := s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) - cs2, err := s3p.Persist(context.Background(), mtb, nil, &Stats{}) + cs2, _, err := s3p.Persist(context.Background(), mtb, nil, nil, &Stats{}) require.NoError(t, err) sources := chunkSources{cs1, cs2} @@ -450,7 +450,7 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) { mt := newMemTable(uint64(2 * targetPartSize)) mt.addChunk(computeAddr(smallChunks[i]), smallChunks[i]) var err error - sources[i], err = s3p.Persist(context.Background(), mt, nil, &Stats{}) + sources[i], _, err = s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) } @@ -461,7 +461,7 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) { } var err error - cs, err := s3p.Persist(context.Background(), mt, nil, &Stats{}) + cs, _, err := s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) sources = append(sources, cs) @@ -474,7 +474,7 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) { mt.addChunk(computeAddr(medChunks[i]), medChunks[i]) } - cs, err = s3p.Persist(context.Background(), mt, nil, &Stats{}) + cs, _, err = s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) sources = append(sources, cs) diff --git a/go/store/nbs/bs_persister.go b/go/store/nbs/bs_persister.go index 9aca6ecd73b..274bbcef2e5 100644 --- a/go/store/nbs/bs_persister.go +++ b/go/store/nbs/bs_persister.go @@ -45,12 +45,16 @@ var _ tableFilePersister = &blobstorePersister{} // Persist makes the contents of mt durable. Chunks already present in // |haver| may be dropped in the process. 
-func (bsp *blobstorePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, stats *Stats) (chunkSource, error) { - address, data, chunkCount, err := mt.write(haver, stats) +func (bsp *blobstorePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, keeper keeperF, stats *Stats) (chunkSource, gcBehavior, error) { + address, data, chunkCount, gcb, err := mt.write(haver, keeper, stats) if err != nil { - return emptyChunkSource{}, err - } else if chunkCount == 0 { - return emptyChunkSource{}, nil + return emptyChunkSource{}, gcBehavior_Continue, err + } + if gcb != gcBehavior_Continue { + return emptyChunkSource{}, gcb, nil + } + if chunkCount == 0 { + return emptyChunkSource{}, gcBehavior_Continue, nil } name := address.String() @@ -59,24 +63,28 @@ func (bsp *blobstorePersister) Persist(ctx context.Context, mt *memTable, haver // first write table records and tail (index+footer) as separate blobs eg, ectx := errgroup.WithContext(ctx) - eg.Go(func() (err error) { - _, err = bsp.bs.Put(ectx, name+tableRecordsExt, int64(len(records)), bytes.NewBuffer(records)) - return + eg.Go(func() error { + _, err := bsp.bs.Put(ectx, name+tableRecordsExt, int64(len(records)), bytes.NewBuffer(records)) + return err }) - eg.Go(func() (err error) { - _, err = bsp.bs.Put(ectx, name+tableTailExt, int64(len(tail)), bytes.NewBuffer(tail)) - return + eg.Go(func() error { + _, err := bsp.bs.Put(ectx, name+tableTailExt, int64(len(tail)), bytes.NewBuffer(tail)) + return err }) if err = eg.Wait(); err != nil { - return nil, err + return nil, gcBehavior_Continue, err } // then concatenate into a final blob if _, err = bsp.bs.Concatenate(ctx, name, []string{name + tableRecordsExt, name + tableTailExt}); err != nil { - return emptyChunkSource{}, err + return emptyChunkSource{}, gcBehavior_Continue, err } rdr := &bsTableReaderAt{name, bsp.bs} - return newReaderFromIndexData(ctx, bsp.q, data, address, rdr, bsp.blockSize) + src, err := newReaderFromIndexData(ctx, bsp.q, data, address, rdr, bsp.blockSize) + if err != nil { + return emptyChunkSource{}, gcBehavior_Continue, err + } + return src, gcBehavior_Continue, nil } // ConjoinAll implements tablePersister. 
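Every tablePersister.Persist above now accepts a keeperF and returns a gcBehavior next to the chunkSource, and each implementation bails out with gcBehavior_Block before writing when the memtable races a collection. The following is a sketch of the resulting call-site pattern; persistWithGCCheck and its sentinel error are assumptions (and would have to live in package nbs, since the types are unexported), not code from the patch, which simply propagates the gcBehavior value upward.

package nbs

import (
	"context"
	"errors"
)

// errPersistBlockedByGC is a sketch-only sentinel; the real callers in this
// patch return the gcBehavior value to their callers instead of converting it.
var errPersistBlockedByGC = errors.New("persist blocked by concurrent GC")

func persistWithGCCheck(ctx context.Context, p tablePersister, mt *memTable, haver chunkReader,
	keeper keeperF, stats *Stats) (chunkSource, error) {
	src, gcb, err := p.Persist(ctx, mt, haver, keeper, stats)
	if err != nil {
		return nil, err
	}
	if gcb != gcBehavior_Continue {
		// A keeper claimed one of the memtable's addresses mid-write; the
		// persisters above return gcBehavior_Block with a nil error in this case.
		return nil, errPersistBlockedByGC
	}
	return src, nil
}
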
diff --git a/go/store/nbs/byte_sink_test.go b/go/store/nbs/byte_sink_test.go index d388f2d42df..7717e32a253 100644 --- a/go/store/nbs/byte_sink_test.go +++ b/go/store/nbs/byte_sink_test.go @@ -27,7 +27,7 @@ import ( ) func TestBlockBufferTableSink(t *testing.T) { - createSink := func() ByteSink { + createSink := func(*testing.T) ByteSink { return NewBlockBufferByteSink(128) } @@ -35,7 +35,7 @@ func TestBlockBufferTableSink(t *testing.T) { } func TestFixedBufferTableSink(t *testing.T) { - createSink := func() ByteSink { + createSink := func(*testing.T) ByteSink { return NewFixedBufferByteSink(make([]byte, 32*1024)) } @@ -43,8 +43,8 @@ func TestFixedBufferTableSink(t *testing.T) { } func TestBufferedFileByteSink(t *testing.T) { - createSink := func() ByteSink { - sink, err := NewBufferedFileByteSink("", 4*1024, 16) + createSink := func(t *testing.T) ByteSink { + sink, err := NewBufferedFileByteSink(t.TempDir(), 4*1024, 16) require.NoError(t, err) return sink @@ -53,7 +53,7 @@ func TestBufferedFileByteSink(t *testing.T) { suite.Run(t, &TableSinkSuite{createSink, t}) t.Run("ReaderTwice", func(t *testing.T) { - sink, err := NewBufferedFileByteSink("", 4*1024, 16) + sink, err := NewBufferedFileByteSink(t.TempDir(), 4*1024, 16) require.NoError(t, err) _, err = sink.Write([]byte{1, 2, 3, 4}) require.NoError(t, err) @@ -76,7 +76,7 @@ func TestBufferedFileByteSink(t *testing.T) { } type TableSinkSuite struct { - sinkFactory func() ByteSink + sinkFactory func(*testing.T) ByteSink t *testing.T } @@ -116,7 +116,7 @@ func verifyContents(t *testing.T, bytes []byte) { } func (suite *TableSinkSuite) TestWriteAndFlush() { - sink := suite.sinkFactory() + sink := suite.sinkFactory(suite.t) err := writeToSink(sink) require.NoError(suite.t, err) @@ -128,7 +128,7 @@ func (suite *TableSinkSuite) TestWriteAndFlush() { } func (suite *TableSinkSuite) TestWriteAndFlushToFile() { - sink := suite.sinkFactory() + sink := suite.sinkFactory(suite.t) err := writeToSink(sink) require.NoError(suite.t, err) diff --git a/go/store/nbs/cmp_chunk_table_writer_test.go b/go/store/nbs/cmp_chunk_table_writer_test.go index 170cc43cb64..33323d018b0 100644 --- a/go/store/nbs/cmp_chunk_table_writer_test.go +++ b/go/store/nbs/cmp_chunk_table_writer_test.go @@ -51,12 +51,12 @@ func TestCmpChunkTableWriter(t *testing.T) { found := make([]CompressedChunk, 0) eg, egCtx := errgroup.WithContext(ctx) - _, err = tr.getManyCompressed(egCtx, eg, reqs, func(ctx context.Context, c CompressedChunk) { found = append(found, c) }, &Stats{}) + _, _, err = tr.getManyCompressed(egCtx, eg, reqs, func(ctx context.Context, c CompressedChunk) { found = append(found, c) }, nil, &Stats{}) require.NoError(t, err) require.NoError(t, eg.Wait()) // for all the chunks we find, write them using the compressed writer - tw, err := NewCmpChunkTableWriter("") + tw, err := NewCmpChunkTableWriter(t.TempDir()) require.NoError(t, err) for _, cmpChnk := range found { err = tw.AddCmpChunk(cmpChnk) @@ -67,7 +67,7 @@ func TestCmpChunkTableWriter(t *testing.T) { require.NoError(t, err) t.Run("ErrDuplicateChunkWritten", func(t *testing.T) { - tw, err := NewCmpChunkTableWriter("") + tw, err := NewCmpChunkTableWriter(t.TempDir()) require.NoError(t, err) for _, cmpChnk := range found { err = tw.AddCmpChunk(cmpChnk) @@ -96,7 +96,7 @@ func TestCmpChunkTableWriter(t *testing.T) { } func TestCmpChunkTableWriterGhostChunk(t *testing.T) { - tw, err := NewCmpChunkTableWriter("") + tw, err := NewCmpChunkTableWriter(t.TempDir()) require.NoError(t, err) require.Error(t, 
tw.AddCmpChunk(NewGhostCompressedChunk(hash.Parse("6af71afc2ea0hmp4olev0vp9q1q5gvb1")))) } @@ -146,7 +146,7 @@ func readAllChunks(ctx context.Context, hashes hash.HashSet, reader tableReader) reqs := toGetRecords(hashes) found := make([]*chunks.Chunk, 0) eg, ctx := errgroup.WithContext(ctx) - _, err := reader.getMany(ctx, eg, reqs, func(ctx context.Context, c *chunks.Chunk) { found = append(found, c) }, &Stats{}) + _, _, err := reader.getMany(ctx, eg, reqs, func(ctx context.Context, c *chunks.Chunk) { found = append(found, c) }, nil, &Stats{}) if err != nil { return nil, err } diff --git a/go/store/nbs/conjoiner_test.go b/go/store/nbs/conjoiner_test.go index 846aa4c7f4a..a9f64aa2220 100644 --- a/go/store/nbs/conjoiner_test.go +++ b/go/store/nbs/conjoiner_test.go @@ -63,7 +63,7 @@ func makeTestSrcs(t *testing.T, tableSizes []uint32, p tablePersister) (srcs chu c := nextChunk() mt.addChunk(computeAddr(c), c) } - cs, err := p.Persist(context.Background(), mt, nil, &Stats{}) + cs, _, err := p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) c, err := cs.clone() require.NoError(t, err) @@ -159,11 +159,11 @@ func testConjoin(t *testing.T, factory func(t *testing.T) tablePersister) { var ok bool for _, act := range actualSrcs { var err error - ok, err = act.has(rec.a) + ok, _, err = act.has(rec.a, nil) require.NoError(t, err) var buf []byte if ok { - buf, err = act.get(ctx, rec.a, stats) + buf, _, err = act.get(ctx, rec.a, nil, stats) require.NoError(t, err) assert.Equal(t, rec.data, buf) break @@ -180,7 +180,7 @@ func testConjoin(t *testing.T, factory func(t *testing.T) tablePersister) { mt := newMemTable(testMemTableSize) data := []byte{0xde, 0xad} mt.addChunk(computeAddr(data), data) - src, err := p.Persist(context.Background(), mt, nil, &Stats{}) + src, _, err := p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) defer src.close() return tableSpec{src.hash(), mustUint32(src.count())} diff --git a/go/store/nbs/empty_chunk_source.go b/go/store/nbs/empty_chunk_source.go index 5df2696c33d..8d00c820de4 100644 --- a/go/store/nbs/empty_chunk_source.go +++ b/go/store/nbs/empty_chunk_source.go @@ -34,24 +34,24 @@ import ( type emptyChunkSource struct{} -func (ecs emptyChunkSource) has(h hash.Hash) (bool, error) { - return false, nil +func (ecs emptyChunkSource) has(h hash.Hash, _ keeperF) (bool, gcBehavior, error) { + return false, gcBehavior_Continue, nil } -func (ecs emptyChunkSource) hasMany(addrs []hasRecord) (bool, error) { - return true, nil +func (ecs emptyChunkSource) hasMany(addrs []hasRecord, _ keeperF) (bool, gcBehavior, error) { + return true, gcBehavior_Continue, nil } -func (ecs emptyChunkSource) get(ctx context.Context, h hash.Hash, stats *Stats) ([]byte, error) { - return nil, nil +func (ecs emptyChunkSource) get(ctx context.Context, h hash.Hash, keeper keeperF, stats *Stats) ([]byte, gcBehavior, error) { + return nil, gcBehavior_Continue, nil } -func (ecs emptyChunkSource) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (bool, error) { - return true, nil +func (ecs emptyChunkSource) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { + return true, gcBehavior_Continue, nil } -func (ecs emptyChunkSource) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) 
(bool, error) { - return true, nil +func (ecs emptyChunkSource) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { + return true, gcBehavior_Continue, nil } func (ecs emptyChunkSource) count() (uint32, error) { @@ -74,8 +74,8 @@ func (ecs emptyChunkSource) reader(context.Context) (io.ReadCloser, uint64, erro return io.NopCloser(&bytes.Buffer{}), 0, nil } -func (ecs emptyChunkSource) getRecordRanges(ctx context.Context, requests []getRecord) (map[hash.Hash]Range, error) { - return map[hash.Hash]Range{}, nil +func (ecs emptyChunkSource) getRecordRanges(ctx context.Context, requests []getRecord, keeper keeperF) (map[hash.Hash]Range, gcBehavior, error) { + return map[hash.Hash]Range{}, gcBehavior_Continue, nil } func (ecs emptyChunkSource) currentSize() uint64 { diff --git a/go/store/nbs/file_table_persister.go b/go/store/nbs/file_table_persister.go index 5868b982174..83175088f29 100644 --- a/go/store/nbs/file_table_persister.go +++ b/go/store/nbs/file_table_persister.go @@ -86,16 +86,23 @@ func (ftp *fsTablePersister) Exists(ctx context.Context, name hash.Hash, chunkCo return archiveFileExists(ctx, ftp.dir, name) } -func (ftp *fsTablePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, stats *Stats) (chunkSource, error) { +func (ftp *fsTablePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, keeper keeperF, stats *Stats) (chunkSource, gcBehavior, error) { t1 := time.Now() defer stats.PersistLatency.SampleTimeSince(t1) - name, data, chunkCount, err := mt.write(haver, stats) + name, data, chunkCount, gcb, err := mt.write(haver, keeper, stats) if err != nil { - return emptyChunkSource{}, err + return emptyChunkSource{}, gcBehavior_Continue, err + } + if gcb != gcBehavior_Continue { + return emptyChunkSource{}, gcb, nil } - return ftp.persistTable(ctx, name, data, chunkCount, stats) + src, err := ftp.persistTable(ctx, name, data, chunkCount, stats) + if err != nil { + return emptyChunkSource{}, gcBehavior_Continue, err + } + return src, gcBehavior_Continue, nil } func (ftp *fsTablePersister) Path() string { diff --git a/go/store/nbs/file_table_persister_test.go b/go/store/nbs/file_table_persister_test.go index 00e57d2fc95..4adde94986b 100644 --- a/go/store/nbs/file_table_persister_test.go +++ b/go/store/nbs/file_table_persister_test.go @@ -96,7 +96,8 @@ func persistTableData(p tablePersister, chunx ...[]byte) (src chunkSource, err e return nil, fmt.Errorf("memTable too full to add %s", computeAddr(c)) } } - return p.Persist(context.Background(), mt, nil, &Stats{}) + src, _, err = p.Persist(context.Background(), mt, nil, nil, &Stats{}) + return src, err } func TestFSTablePersisterPersistNoData(t *testing.T) { @@ -113,7 +114,7 @@ func TestFSTablePersisterPersistNoData(t *testing.T) { defer file.RemoveAll(dir) fts := newFSTablePersister(dir, &UnlimitedQuotaProvider{}) - src, err := fts.Persist(context.Background(), mt, existingTable, &Stats{}) + src, _, err := fts.Persist(context.Background(), mt, existingTable, nil, &Stats{}) require.NoError(t, err) assert.True(mustUint32(src.count()) == 0) @@ -177,7 +178,7 @@ func TestFSTablePersisterConjoinAllDups(t *testing.T) { } var err error - sources[0], err = fts.Persist(ctx, mt, nil, &Stats{}) + sources[0], _, err = fts.Persist(ctx, mt, nil, nil, &Stats{}) require.NoError(t, err) sources[1], err = sources[0].clone() require.NoError(t, err) diff --git a/go/store/nbs/frag/main.go 
b/go/store/nbs/frag/main.go index 424259916e7..7eb9c2d8db8 100644 --- a/go/store/nbs/frag/main.go +++ b/go/store/nbs/frag/main.go @@ -153,14 +153,14 @@ func main() { if i+1 == numGroups { // last group go func(i int) { defer wg.Done() - reads[i], _, err = nbs.CalcReads(store, orderedChildren[i*branchFactor:].HashSet(), 0) + reads[i], _, _, err = nbs.CalcReads(store, orderedChildren[i*branchFactor:].HashSet(), 0, nil) d.PanicIfError(err) }(i) continue } go func(i int) { defer wg.Done() - reads[i], _, err = nbs.CalcReads(store, orderedChildren[i*branchFactor:(i+1)*branchFactor].HashSet(), 0) + reads[i], _, _, err = nbs.CalcReads(store, orderedChildren[i*branchFactor:(i+1)*branchFactor].HashSet(), 0, nil) d.PanicIfError(err) }(i) } diff --git a/go/store/nbs/generational_chunk_store.go b/go/store/nbs/generational_chunk_store.go index 64846797ad3..cfbbf33e4ba 100644 --- a/go/store/nbs/generational_chunk_store.go +++ b/go/store/nbs/generational_chunk_store.go @@ -118,7 +118,9 @@ func (gcs *GenerationalNBS) GetMany(ctx context.Context, hashes hash.HashSet, fo return nil } - err = gcs.newGen.GetMany(ctx, notFound, func(ctx context.Context, chunk *chunks.Chunk) { + hashes = notFound + notFound = hashes.Copy() + err = gcs.newGen.GetMany(ctx, hashes, func(ctx context.Context, chunk *chunks.Chunk) { func() { mu.Lock() defer mu.Unlock() @@ -143,14 +145,18 @@ func (gcs *GenerationalNBS) GetMany(ctx context.Context, hashes hash.HashSet, fo } func (gcs *GenerationalNBS) GetManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, CompressedChunk)) error { + return gcs.getManyCompressed(ctx, hashes, found, gcDependencyMode_TakeDependency) +} + +func (gcs *GenerationalNBS) getManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, CompressedChunk), gcDepMode gcDependencyMode) error { var mu sync.Mutex notInOldGen := hashes.Copy() - err := gcs.oldGen.GetManyCompressed(ctx, hashes, func(ctx context.Context, chunk CompressedChunk) { + err := gcs.oldGen.getManyCompressed(ctx, hashes, func(ctx context.Context, chunk CompressedChunk) { mu.Lock() delete(notInOldGen, chunk.Hash()) mu.Unlock() found(ctx, chunk) - }) + }, gcDepMode) if err != nil { return err } @@ -159,12 +165,12 @@ func (gcs *GenerationalNBS) GetManyCompressed(ctx context.Context, hashes hash.H } notFound := notInOldGen.Copy() - err = gcs.newGen.GetManyCompressed(ctx, notInOldGen, func(ctx context.Context, chunk CompressedChunk) { + err = gcs.newGen.getManyCompressed(ctx, notInOldGen, func(ctx context.Context, chunk CompressedChunk) { mu.Lock() delete(notFound, chunk.Hash()) mu.Unlock() found(ctx, chunk) - }) + }, gcDepMode) if err != nil { return err } @@ -174,7 +180,7 @@ func (gcs *GenerationalNBS) GetManyCompressed(ctx context.Context, hashes hash.H // The missing chunks may be ghost chunks. if gcs.ghostGen != nil { - return gcs.ghostGen.GetManyCompressed(ctx, notFound, found) + return gcs.ghostGen.getManyCompressed(ctx, notFound, found, gcDepMode) } return nil } @@ -202,14 +208,30 @@ func (gcs *GenerationalNBS) Has(ctx context.Context, h hash.Hash) (bool, error) } // HasMany returns a new HashSet containing any members of |hashes| that are absent from the store. 
-func (gcs *GenerationalNBS) HasMany(ctx context.Context, hashes hash.HashSet) (absent hash.HashSet, err error) { - gcs.newGen.mu.RLock() - defer gcs.newGen.mu.RUnlock() - return gcs.hasMany(toHasRecords(hashes)) +func (gcs *GenerationalNBS) HasMany(ctx context.Context, hashes hash.HashSet) (hash.HashSet, error) { + absent, err := gcs.newGen.HasMany(ctx, hashes) + if err != nil { + return nil, err + } + if len(absent) == 0 { + return nil, err + } + + absent, err = gcs.oldGen.HasMany(ctx, absent) + if err != nil { + return nil, err + } + if len(absent) == 0 || gcs.ghostGen == nil { + return nil, err + } + + return gcs.ghostGen.HasMany(ctx, absent) } -func (gcs *GenerationalNBS) hasMany(recs []hasRecord) (absent hash.HashSet, err error) { - absent, err = gcs.newGen.hasMany(recs) +// |refCheck| is called from write processes in newGen, so it is called with +// newGen.mu held. oldGen.mu is not held however. +func (gcs *GenerationalNBS) refCheck(recs []hasRecord) (hash.HashSet, error) { + absent, err := gcs.newGen.refCheck(recs) if err != nil { return nil, err } else if len(absent) == 0 { @@ -219,12 +241,11 @@ func (gcs *GenerationalNBS) hasMany(recs []hasRecord) (absent hash.HashSet, err absent, err = func() (hash.HashSet, error) { gcs.oldGen.mu.RLock() defer gcs.oldGen.mu.RUnlock() - return gcs.oldGen.hasMany(recs) + return gcs.oldGen.refCheck(recs) }() if err != nil { return nil, err } - if len(absent) == 0 || gcs.ghostGen == nil { return absent, nil } @@ -237,7 +258,7 @@ func (gcs *GenerationalNBS) hasMany(recs []hasRecord) (absent hash.HashSet, err // to Flush(). Put may be called concurrently with other calls to Put(), // Get(), GetMany(), Has() and HasMany(). func (gcs *GenerationalNBS) Put(ctx context.Context, c chunks.Chunk, getAddrs chunks.GetAddrsCurry) error { - return gcs.newGen.putChunk(ctx, c, getAddrs, gcs.hasMany) + return gcs.newGen.putChunk(ctx, c, getAddrs, gcs.refCheck) } // Returns the NomsBinFormat with which this ChunkSource is compatible. @@ -277,7 +298,7 @@ func (gcs *GenerationalNBS) Root(ctx context.Context) (hash.Hash, error) { // persisted root hash from last to current (or keeps it the same). // If last doesn't match the root in persistent storage, returns false. func (gcs *GenerationalNBS) Commit(ctx context.Context, current, last hash.Hash) (bool, error) { - return gcs.newGen.commit(ctx, current, last, gcs.hasMany) + return gcs.newGen.commit(ctx, current, last, gcs.refCheck) } // Stats may return some kind of struct that reports statistics about the @@ -400,18 +421,18 @@ func (gcs *GenerationalNBS) AddTableFilesToManifest(ctx context.Context, fileIdT // PruneTableFiles deletes old table files that are no longer referenced in the manifest of the new or old gen chunkstores func (gcs *GenerationalNBS) PruneTableFiles(ctx context.Context) error { - err := gcs.oldGen.pruneTableFiles(ctx, gcs.hasMany) + err := gcs.oldGen.pruneTableFiles(ctx) if err != nil { return err } - return gcs.newGen.pruneTableFiles(ctx, gcs.hasMany) + return gcs.newGen.pruneTableFiles(ctx) } // SetRootChunk changes the root chunk hash from the previous value to the new root for the newgen cs func (gcs *GenerationalNBS) SetRootChunk(ctx context.Context, root, previous hash.Hash) error { - return gcs.newGen.setRootChunk(ctx, root, previous, gcs.hasMany) + return gcs.newGen.setRootChunk(ctx, root, previous, gcs.refCheck) } // SupportedOperations returns a description of the support TableFile operations. Some stores only support reading table files, not writing. 
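With the rewrite above, HasMany no longer holds newGen.mu and walks hasRecords itself; it filters the requested set through each generation's public HasMany in turn (new gen, then old gen, then the ghost store) and returns a nil set once nothing is absent. A test-style sketch of those semantics follows; the exampleHasMany helper and its hashes are hypothetical and assume placement in package nbs.

package nbs

import (
	"context"
	"fmt"

	"github.com/dolthub/dolt/go/store/hash"
)

func exampleHasMany(ctx context.Context, gcs *GenerationalNBS, h1, h2, h3 hash.Hash) error {
	absent, err := gcs.HasMany(ctx, hash.NewHashSet(h1, h2, h3))
	if err != nil {
		return err
	}
	// If h1 lives in the new gen and h2 in the old gen, only h3 comes back;
	// if all three are present, absent is nil and the loop body never runs.
	for h := range absent {
		fmt.Printf("absent from every generation: %s\n", h.String())
	}
	return nil
}
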
@@ -473,12 +494,37 @@ func (gcs *GenerationalNBS) UpdateManifest(ctx context.Context, updates map[hash return gcs.newGen.UpdateManifest(ctx, updates) } -func (gcs *GenerationalNBS) BeginGC(keeper func(hash.Hash) bool) error { - return gcs.newGen.BeginGC(keeper) +func (gcs *GenerationalNBS) OldGenGCFilter() chunks.HasManyFunc { + return func(ctx context.Context, hashes hash.HashSet) (hash.HashSet, error) { + return gcs.oldGen.hasManyDep(ctx, hashes, gcDependencyMode_NoDependency) + } +} + +func (gcs *GenerationalNBS) BeginGC(keeper func(hash.Hash) bool, mode chunks.GCMode) error { + err := gcs.newGen.BeginGC(keeper, mode) + if err != nil { + return err + } + // In GCMode_Full, the OldGen is also being collected. In normal + // operation, the OldGen is not being collected because it is + // still growing monotonically and nothing in it is at risk of + // going away. In Full mode, we want to take read dependencies + // from the OldGen as well. + if mode == chunks.GCMode_Full { + err = gcs.oldGen.BeginGC(keeper, mode) + if err != nil { + gcs.newGen.EndGC(mode) + return err + } + } + return nil } -func (gcs *GenerationalNBS) EndGC() { - gcs.newGen.EndGC() +func (gcs *GenerationalNBS) EndGC(mode chunks.GCMode) { + if mode == chunks.GCMode_Full { + gcs.oldGen.EndGC(mode) + } + gcs.newGen.EndGC(mode) } func (gcs *GenerationalNBS) MarkAndSweepChunks(ctx context.Context, getAddrs chunks.GetAddrsCurry, filter chunks.HasManyFunc, dest chunks.ChunkStore, mode chunks.GCMode) (chunks.MarkAndSweeper, error) { diff --git a/go/store/nbs/ghost_store.go b/go/store/nbs/ghost_store.go index 11d23de6a68..9edd0fb40fa 100644 --- a/go/store/nbs/ghost_store.go +++ b/go/store/nbs/ghost_store.go @@ -91,6 +91,10 @@ func (g GhostBlockStore) GetMany(ctx context.Context, hashes hash.HashSet, found } func (g GhostBlockStore) GetManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, CompressedChunk)) error { + return g.getManyCompressed(ctx, hashes, found, gcDependencyMode_TakeDependency) +} + +func (g GhostBlockStore) getManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, CompressedChunk), gcDepMode gcDependencyMode) error { for h := range hashes { if g.skippedRefs.Has(h) { found(ctx, NewGhostCompressedChunk(h)) diff --git a/go/store/nbs/journal.go b/go/store/nbs/journal.go index 8f415cbfa3d..dc8deff8e19 100644 --- a/go/store/nbs/journal.go +++ b/go/store/nbs/journal.go @@ -239,17 +239,19 @@ func (j *ChunkJournal) IterateRoots(f func(root string, timestamp *time.Time) er } // Persist implements tablePersister. -func (j *ChunkJournal) Persist(ctx context.Context, mt *memTable, haver chunkReader, stats *Stats) (chunkSource, error) { +func (j *ChunkJournal) Persist(ctx context.Context, mt *memTable, haver chunkReader, keeper keeperF, stats *Stats) (chunkSource, gcBehavior, error) { if j.backing.readOnly() { - return nil, errReadOnlyManifest + return nil, gcBehavior_Continue, errReadOnlyManifest } else if err := j.maybeInit(ctx); err != nil { - return nil, err + return nil, gcBehavior_Continue, err } if haver != nil { sort.Sort(hasRecordByPrefix(mt.order)) // hasMany() requires addresses to be sorted. 
- if _, err := haver.hasMany(mt.order); err != nil { - return nil, err + if _, gcb, err := haver.hasMany(mt.order, keeper); err != nil { + return nil, gcBehavior_Continue, err + } else if gcb != gcBehavior_Continue { + return nil, gcb, nil } sort.Sort(hasRecordByOrder(mt.order)) // restore "insertion" order for write } @@ -261,10 +263,10 @@ func (j *ChunkJournal) Persist(ctx context.Context, mt *memTable, haver chunkRea c := chunks.NewChunkWithHash(hash.Hash(*record.a), mt.chunks[*record.a]) err := j.wr.writeCompressedChunk(ctx, ChunkToCompressedChunk(c)) if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } } - return journalChunkSource{journal: j.wr}, nil + return journalChunkSource{journal: j.wr}, gcBehavior_Continue, nil } // ConjoinAll implements tablePersister. diff --git a/go/store/nbs/journal_chunk_source.go b/go/store/nbs/journal_chunk_source.go index c8dd8a4ac02..e7bf50dc4a9 100644 --- a/go/store/nbs/journal_chunk_source.go +++ b/go/store/nbs/journal_chunk_source.go @@ -39,20 +39,29 @@ type journalChunkSource struct { var _ chunkSource = journalChunkSource{} -func (s journalChunkSource) has(h hash.Hash) (bool, error) { - return s.journal.hasAddr(h), nil +func (s journalChunkSource) has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) { + res := s.journal.hasAddr(h) + if res && keeper != nil && keeper(h) { + return false, gcBehavior_Block, nil + } + return res, gcBehavior_Continue, nil } -func (s journalChunkSource) hasMany(addrs []hasRecord) (missing bool, err error) { +func (s journalChunkSource) hasMany(addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) { + missing := false for i := range addrs { - ok := s.journal.hasAddr(*addrs[i].a) + h := *addrs[i].a + ok := s.journal.hasAddr(h) if ok { + if keeper != nil && keeper(h) { + return true, gcBehavior_Block, nil + } addrs[i].has = true } else { missing = true } } - return + return missing, gcBehavior_Continue, nil } func (s journalChunkSource) getCompressed(ctx context.Context, h hash.Hash, _ *Stats) (CompressedChunk, error) { @@ -60,20 +69,23 @@ func (s journalChunkSource) getCompressed(ctx context.Context, h hash.Hash, _ *S return s.journal.getCompressedChunk(h) } -func (s journalChunkSource) get(ctx context.Context, h hash.Hash, _ *Stats) ([]byte, error) { +func (s journalChunkSource) get(ctx context.Context, h hash.Hash, keeper keeperF, _ *Stats) ([]byte, gcBehavior, error) { defer trace.StartRegion(ctx, "journalChunkSource.get").End() cc, err := s.journal.getCompressedChunk(h) if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } else if cc.IsEmpty() { - return nil, nil + return nil, gcBehavior_Continue, nil + } + if keeper != nil && keeper(h) { + return nil, gcBehavior_Block, nil } ch, err := cc.ToChunk() if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } - return ch.Data(), nil + return ch.Data(), gcBehavior_Continue, nil } type journalRecord struct { @@ -83,7 +95,7 @@ type journalRecord struct { idx int } -func (s journalChunkSource) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (bool, error) { +func (s journalChunkSource) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { return s.getManyCompressed(ctx, eg, reqs, func(ctx context.Context, cc CompressedChunk) { ch, err := cc.ToChunk() if err != nil { @@ -94,7 +106,7 @@ func (s journalChunkSource) 
getMany(ctx context.Context, eg *errgroup.Group, req } chWHash := chunks.NewChunkWithHash(cc.Hash(), ch.Data()) found(ctx, &chWHash) - }, stats) + }, keeper, stats) } // getManyCompressed implements chunkReader. Here we (1) synchronously check @@ -103,7 +115,7 @@ func (s journalChunkSource) getMany(ctx context.Context, eg *errgroup.Group, req // and then (4) asynchronously perform reads. We release the journal read // lock after returning when all reads are completed, which can be after the // function returns. -func (s journalChunkSource) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) (bool, error) { +func (s journalChunkSource) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { defer trace.StartRegion(ctx, "journalChunkSource.getManyCompressed").End() var remaining bool @@ -114,11 +126,16 @@ func (s journalChunkSource) getManyCompressed(ctx context.Context, eg *errgroup. if r.found { continue } - rang, ok := s.journal.ranges.get(*r.a) + h := *r.a + rang, ok := s.journal.ranges.get(h) if !ok { remaining = true continue } + if keeper != nil && keeper(h) { + s.journal.lock.RUnlock() + return true, gcBehavior_Block, nil + } jReqs = append(jReqs, journalRecord{r: rang, idx: i}) reqs[i].found = true } @@ -150,7 +167,7 @@ func (s journalChunkSource) getManyCompressed(ctx context.Context, eg *errgroup. wg.Wait() s.journal.lock.RUnlock() }() - return remaining, nil + return remaining, gcBehavior_Continue, nil } func (s journalChunkSource) count() (uint32, error) { @@ -171,22 +188,26 @@ func (s journalChunkSource) reader(ctx context.Context) (io.ReadCloser, uint64, return rdr, uint64(sz), err } -func (s journalChunkSource) getRecordRanges(ctx context.Context, requests []getRecord) (map[hash.Hash]Range, error) { +func (s journalChunkSource) getRecordRanges(ctx context.Context, requests []getRecord, keeper keeperF) (map[hash.Hash]Range, gcBehavior, error) { ranges := make(map[hash.Hash]Range, len(requests)) for _, req := range requests { if req.found { continue } - rng, ok, err := s.journal.getRange(ctx, *req.a) + h := *req.a + rng, ok, err := s.journal.getRange(ctx, h) if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } else if !ok { continue } + if keeper != nil && keeper(h) { + return nil, gcBehavior_Block, nil + } req.found = true // update |requests| - ranges[hash.Hash(*req.a)] = rng + ranges[h] = rng } - return ranges, nil + return ranges, gcBehavior_Continue, nil } // size implements chunkSource. 
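Editor's note: the journal, memtable, and table-reader changes in this patch all follow one convention — every read path now accepts a keeper callback and returns a gcBehavior alongside its result, and a gcBehavior_Block result means the read raced with a finalizing GC and must be retried once the GC completes. Below is a minimal, self-contained sketch of that contract; the names (toyHas, present, blockingKeeper) are invented for illustration and are not part of this patch.

package main

import "fmt"

type gcBehavior bool

const (
	gcBehavior_Continue gcBehavior = false
	gcBehavior_Block    gcBehavior = true
)

type keeperF func(string) bool

// toyHas mirrors the reader contract above: if the chunk is present but the
// keeper says its hash could not be recorded as a GC read dependency, report
// Block so the caller retries after the GC finalizes.
func toyHas(present map[string]bool, h string, keeper keeperF) (bool, gcBehavior, error) {
	ok := present[h]
	if ok && keeper != nil && keeper(h) {
		return false, gcBehavior_Block, nil
	}
	return ok, gcBehavior_Continue, nil
}

func main() {
	present := map[string]bool{"abc": true}
	// A keeper that always refuses models a GC that is finalizing.
	blockingKeeper := keeperF(func(string) bool { return true })
	attempts := 0
	for {
		keeper := blockingKeeper
		if attempts > 0 {
			keeper = nil // pretend the GC finished before the retry
		}
		ok, gcb, err := toyHas(present, "abc", keeper)
		if err != nil {
			panic(err)
		}
		if gcb == gcBehavior_Block {
			attempts++
			// the real callers wait on the store's condition variable here
			continue
		}
		fmt.Println("has:", ok, "after retries:", attempts)
		return
	}
}

The real callers added later in this patch (NomsBlockStore.Get, Has, and friends) follow the same loop, except that they wait on the store's condition variable rather than retrying immediately.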
diff --git a/go/store/nbs/journal_test.go b/go/store/nbs/journal_test.go index 9486f1edf17..603d1610d43 100644 --- a/go/store/nbs/journal_test.go +++ b/go/store/nbs/journal_test.go @@ -67,14 +67,14 @@ func TestChunkJournalPersist(t *testing.T) { haver := emptyChunkSource{} for i := 0; i < iters; i++ { memTbl, chunkMap := randomMemTable(16) - source, err := j.Persist(ctx, memTbl, haver, stats) + source, _, err := j.Persist(ctx, memTbl, haver, nil, stats) assert.NoError(t, err) for h, ch := range chunkMap { - ok, err := source.has(h) + ok, _, err := source.has(h, nil) assert.NoError(t, err) assert.True(t, ok) - data, err := source.get(ctx, h, stats) + data, _, err := source.get(ctx, h, nil, stats) assert.NoError(t, err) assert.Equal(t, ch.Data(), data) } @@ -96,7 +96,7 @@ func TestReadRecordRanges(t *testing.T) { gets = append(gets, getRecord{a: &h, prefix: h.Prefix()}) } - jcs, err := j.Persist(ctx, mt, emptyChunkSource{}, &Stats{}) + jcs, _, err := j.Persist(ctx, mt, emptyChunkSource{}, nil, &Stats{}) require.NoError(t, err) rdr, sz, err := jcs.(journalChunkSource).journal.snapshot(context.Background()) @@ -108,11 +108,11 @@ func TestReadRecordRanges(t *testing.T) { require.NoError(t, err) assert.Equal(t, int(sz), n) - ranges, err := jcs.getRecordRanges(ctx, gets) + ranges, _, err := jcs.getRecordRanges(ctx, gets, nil) require.NoError(t, err) for h, rng := range ranges { - b, err := jcs.get(ctx, h, &Stats{}) + b, _, err := jcs.get(ctx, h, nil, &Stats{}) assert.NoError(t, err) ch1 := chunks.NewChunkWithHash(h, b) assert.Equal(t, data[h], ch1) diff --git a/go/store/nbs/journal_writer_test.go b/go/store/nbs/journal_writer_test.go index df8c45946f3..77263f23ab2 100644 --- a/go/store/nbs/journal_writer_test.go +++ b/go/store/nbs/journal_writer_test.go @@ -228,7 +228,7 @@ func TestJournalWriterBootstrap(t *testing.T) { source := journalChunkSource{journal: j} for a, cc := range data { - buf, err := source.get(ctx, a, nil) + buf, _, err := source.get(ctx, a, nil, nil) require.NoError(t, err) ch, err := cc.ToChunk() require.NoError(t, err) @@ -279,9 +279,7 @@ func TestJournalWriterSyncClose(t *testing.T) { } func newTestFilePath(t *testing.T) string { - path, err := os.MkdirTemp("", "") - require.NoError(t, err) - return filepath.Join(path, "journal.log") + return filepath.Join(t.TempDir(), "journal.log") } func TestJournalIndexBootstrap(t *testing.T) { @@ -398,6 +396,8 @@ func TestJournalIndexBootstrap(t *testing.T) { require.True(t, ok) _, err = jnl.bootstrapJournal(ctx, nil) assert.Error(t, err) + err = jnl.Close() + require.NoError(t, err) }) } } diff --git a/go/store/nbs/mem_table.go b/go/store/nbs/mem_table.go index cbffa34a72b..1fd8c0ffcda 100644 --- a/go/store/nbs/mem_table.go +++ b/go/store/nbs/mem_table.go @@ -61,7 +61,7 @@ func writeChunksToMT(mt *memTable, chunks []chunks.Chunk) (string, []byte, error } var stats Stats - name, data, count, err := mt.write(nil, &stats) + name, data, count, _, err := mt.write(nil, nil, &stats) if err != nil { return "", nil, err @@ -135,22 +135,27 @@ func (mt *memTable) uncompressedLen() (uint64, error) { return mt.totalData, nil } -func (mt *memTable) has(h hash.Hash) (bool, error) { +func (mt *memTable) has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) { _, has := mt.chunks[h] - return has, nil + if has && keeper != nil && keeper(h) { + return false, gcBehavior_Block, nil + } + return has, gcBehavior_Continue, nil } -func (mt *memTable) hasMany(addrs []hasRecord) (bool, error) { +func (mt *memTable) hasMany(addrs []hasRecord, keeper keeperF) 
(bool, gcBehavior, error) { var remaining bool for i, addr := range addrs { if addr.has { continue } - ok, err := mt.has(*addr.a) - + ok, gcb, err := mt.has(*addr.a, keeper) if err != nil { - return false, err + return false, gcBehavior_Continue, err + } + if gcb != gcBehavior_Continue { + return ok, gcb, nil } if ok { @@ -159,18 +164,25 @@ func (mt *memTable) hasMany(addrs []hasRecord) (bool, error) { remaining = true } } - return remaining, nil + return remaining, gcBehavior_Continue, nil } -func (mt *memTable) get(ctx context.Context, h hash.Hash, stats *Stats) ([]byte, error) { - return mt.chunks[h], nil +func (mt *memTable) get(ctx context.Context, h hash.Hash, keeper keeperF, stats *Stats) ([]byte, gcBehavior, error) { + c, ok := mt.chunks[h] + if ok && keeper != nil && keeper(h) { + return nil, gcBehavior_Block, nil + } + return c, gcBehavior_Continue, nil } -func (mt *memTable) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (bool, error) { +func (mt *memTable) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { var remaining bool for i, r := range reqs { data := mt.chunks[*r.a] if data != nil { + if keeper != nil && keeper(*r.a) { + return true, gcBehavior_Block, nil + } c := chunks.NewChunkWithHash(hash.Hash(*r.a), data) reqs[i].found = true found(ctx, &c) @@ -178,14 +190,17 @@ func (mt *memTable) getMany(ctx context.Context, eg *errgroup.Group, reqs []getR remaining = true } } - return remaining, nil + return remaining, gcBehavior_Continue, nil } -func (mt *memTable) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) (bool, error) { +func (mt *memTable) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { var remaining bool for i, r := range reqs { data := mt.chunks[*r.a] if data != nil { + if keeper != nil && keeper(*r.a) { + return true, gcBehavior_Block, nil + } c := chunks.NewChunkWithHash(hash.Hash(*r.a), data) reqs[i].found = true found(ctx, ChunkToCompressedChunk(c)) @@ -194,7 +209,7 @@ func (mt *memTable) getManyCompressed(ctx context.Context, eg *errgroup.Group, r } } - return remaining, nil + return remaining, gcBehavior_Continue, nil } func (mt *memTable) extract(ctx context.Context, chunks chan<- extractRecord) error { @@ -205,10 +220,11 @@ func (mt *memTable) extract(ctx context.Context, chunks chan<- extractRecord) er return nil } -func (mt *memTable) write(haver chunkReader, stats *Stats) (name hash.Hash, data []byte, count uint32, err error) { +func (mt *memTable) write(haver chunkReader, keeper keeperF, stats *Stats) (name hash.Hash, data []byte, count uint32, gcb gcBehavior, err error) { + gcb = gcBehavior_Continue numChunks := uint64(len(mt.order)) if numChunks == 0 { - return hash.Hash{}, nil, 0, fmt.Errorf("mem table cannot write with zero chunks") + return hash.Hash{}, nil, 0, gcBehavior_Continue, fmt.Errorf("mem table cannot write with zero chunks") } maxSize := maxTableSize(uint64(len(mt.order)), mt.totalData) // todo: memory quota @@ -217,10 +233,12 @@ func (mt *memTable) write(haver chunkReader, stats *Stats) (name hash.Hash, data if haver != nil { sort.Sort(hasRecordByPrefix(mt.order)) // hasMany() requires addresses to be sorted. 
- _, err := haver.hasMany(mt.order) - + _, gcb, err = haver.hasMany(mt.order, keeper) if err != nil { - return hash.Hash{}, nil, 0, err + return hash.Hash{}, nil, 0, gcBehavior_Continue, err + } + if gcb != gcBehavior_Continue { + return hash.Hash{}, nil, 0, gcb, err } sort.Sort(hasRecordByOrder(mt.order)) // restore "insertion" order for write @@ -236,7 +254,7 @@ func (mt *memTable) write(haver chunkReader, stats *Stats) (name hash.Hash, data tableSize, name, err := tw.finish() if err != nil { - return hash.Hash{}, nil, 0, err + return hash.Hash{}, nil, 0, gcBehavior_Continue, err } if count > 0 { @@ -246,7 +264,7 @@ func (mt *memTable) write(haver chunkReader, stats *Stats) (name hash.Hash, data stats.ChunksPerPersist.Sample(uint64(count)) } - return name, buff[:tableSize], count, nil + return name, buff[:tableSize], count, gcBehavior_Continue, nil } func (mt *memTable) close() error { diff --git a/go/store/nbs/mem_table_test.go b/go/store/nbs/mem_table_test.go index 250f2994636..647a8395dce 100644 --- a/go/store/nbs/mem_table_test.go +++ b/go/store/nbs/mem_table_test.go @@ -69,12 +69,7 @@ func TestWriteChunks(t *testing.T) { t.Error(err) } - dir, err := os.MkdirTemp("", "write_chunks_test") - if err != nil { - t.Error(err) - } - - err = os.WriteFile(dir+name, data, os.ModePerm) + err = os.WriteFile(t.TempDir()+name, data, os.ModePerm) if err != nil { t.Error(err) } @@ -97,14 +92,14 @@ func TestMemTableAddHasGetChunk(t *testing.T) { assertChunksInReader(chunks, mt, assert) for _, c := range chunks { - data, err := mt.get(context.Background(), computeAddr(c), &Stats{}) + data, _, err := mt.get(context.Background(), computeAddr(c), nil, &Stats{}) require.NoError(t, err) assert.Equal(bytes.Compare(c, data), 0) } notPresent := []byte("nope") - assert.False(mt.has(computeAddr(notPresent))) - assert.Nil(mt.get(context.Background(), computeAddr(notPresent), &Stats{})) + assert.False(mt.has(computeAddr(notPresent), nil)) + assert.Nil(mt.get(context.Background(), computeAddr(notPresent), nil, &Stats{})) } func TestMemTableAddOverflowChunk(t *testing.T) { @@ -117,9 +112,9 @@ func TestMemTableAddOverflowChunk(t *testing.T) { bigAddr := computeAddr(big) mt := newMemTable(memTableSize) assert.Equal(mt.addChunk(bigAddr, big), chunkAdded) - assert.True(mt.has(bigAddr)) + assert.True(mt.has(bigAddr, nil)) assert.Equal(mt.addChunk(computeAddr(little), little), chunkNotAdded) - assert.False(mt.has(computeAddr(little))) + assert.False(mt.has(computeAddr(little), nil)) } { @@ -127,12 +122,12 @@ func TestMemTableAddOverflowChunk(t *testing.T) { bigAddr := computeAddr(big) mt := newMemTable(memTableSize) assert.Equal(mt.addChunk(bigAddr, big), chunkAdded) - assert.True(mt.has(bigAddr)) + assert.True(mt.has(bigAddr, nil)) assert.Equal(mt.addChunk(computeAddr(little), little), chunkAdded) - assert.True(mt.has(computeAddr(little))) + assert.True(mt.has(computeAddr(little), nil)) other := []byte("o") assert.Equal(mt.addChunk(computeAddr(other), other), chunkNotAdded) - assert.False(mt.has(computeAddr(other))) + assert.False(mt.has(computeAddr(other), nil)) } } @@ -158,7 +153,7 @@ func TestMemTableWrite(t *testing.T) { tr1, err := newTableReader(ti1, tableReaderAtFromBytes(td1), fileBlockSize) require.NoError(t, err) defer tr1.close() - assert.True(tr1.has(computeAddr(chunks[1]))) + assert.True(tr1.has(computeAddr(chunks[1]), nil)) td2, _, err := buildTable(chunks[2:]) require.NoError(t, err) @@ -167,9 +162,9 @@ func TestMemTableWrite(t *testing.T) { tr2, err := newTableReader(ti2, tableReaderAtFromBytes(td2), 
fileBlockSize) require.NoError(t, err) defer tr2.close() - assert.True(tr2.has(computeAddr(chunks[2]))) + assert.True(tr2.has(computeAddr(chunks[2]), nil)) - _, data, count, err := mt.write(chunkReaderGroup{tr1, tr2}, &Stats{}) + _, data, count, _, err := mt.write(chunkReaderGroup{tr1, tr2}, nil, &Stats{}) require.NoError(t, err) assert.Equal(uint32(1), count) @@ -178,9 +173,9 @@ func TestMemTableWrite(t *testing.T) { outReader, err := newTableReader(ti, tableReaderAtFromBytes(data), fileBlockSize) require.NoError(t, err) defer outReader.close() - assert.True(outReader.has(computeAddr(chunks[0]))) - assert.False(outReader.has(computeAddr(chunks[1]))) - assert.False(outReader.has(computeAddr(chunks[2]))) + assert.True(outReader.has(computeAddr(chunks[0]), nil)) + assert.False(outReader.has(computeAddr(chunks[1]), nil)) + assert.False(outReader.has(computeAddr(chunks[2]), nil)) } type tableReaderAtAdapter struct { @@ -223,7 +218,7 @@ func TestMemTableSnappyWriteOutOfLine(t *testing.T) { } mt.snapper = &outOfLineSnappy{[]bool{false, true, false}} // chunks[1] should trigger a panic - assert.Panics(func() { mt.write(nil, &Stats{}) }) + assert.Panics(func() { mt.write(nil, nil, &Stats{}) }) } type outOfLineSnappy struct { @@ -244,72 +239,82 @@ func (o *outOfLineSnappy) Encode(dst, src []byte) []byte { type chunkReaderGroup []chunkReader -func (crg chunkReaderGroup) has(h hash.Hash) (bool, error) { +func (crg chunkReaderGroup) has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) { for _, haver := range crg { - ok, err := haver.has(h) - + ok, gcb, err := haver.has(h, keeper) if err != nil { - return false, err + return false, gcb, err + } + if gcb != gcBehavior_Continue { + return true, gcb, nil } if ok { - return true, nil + return true, gcb, nil } } - - return false, nil + return false, gcBehavior_Continue, nil } -func (crg chunkReaderGroup) get(ctx context.Context, h hash.Hash, stats *Stats) ([]byte, error) { +func (crg chunkReaderGroup) get(ctx context.Context, h hash.Hash, keeper keeperF, stats *Stats) ([]byte, gcBehavior, error) { for _, haver := range crg { - if data, err := haver.get(ctx, h, stats); err != nil { - return nil, err + if data, gcb, err := haver.get(ctx, h, keeper, stats); err != nil { + return nil, gcb, err + } else if gcb != gcBehavior_Continue { + return nil, gcb, nil } else if data != nil { - return data, nil + return data, gcb, nil } } - return nil, nil + return nil, gcBehavior_Continue, nil } -func (crg chunkReaderGroup) hasMany(addrs []hasRecord) (bool, error) { +func (crg chunkReaderGroup) hasMany(addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) { for _, haver := range crg { - remaining, err := haver.hasMany(addrs) - + remaining, gcb, err := haver.hasMany(addrs, keeper) if err != nil { - return false, err + return false, gcb, err + } + if gcb != gcBehavior_Continue { + return false, gcb, nil } - if !remaining { - return false, nil + return false, gcb, nil } } - return true, nil + return true, gcBehavior_Continue, nil } -func (crg chunkReaderGroup) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (bool, error) { +func (crg chunkReaderGroup) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { for _, haver := range crg { - remaining, err := haver.getMany(ctx, eg, reqs, found, stats) + remaining, gcb, err := haver.getMany(ctx, eg, reqs, found, keeper, stats) if err 
!= nil { - return true, err + return true, gcb, err + } + if gcb != gcBehavior_Continue { + return true, gcb, nil } if !remaining { - return false, nil + return false, gcb, nil } } - return true, nil + return true, gcBehavior_Continue, nil } -func (crg chunkReaderGroup) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) (bool, error) { +func (crg chunkReaderGroup) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { for _, haver := range crg { - remaining, err := haver.getManyCompressed(ctx, eg, reqs, found, stats) + remaining, gcb, err := haver.getManyCompressed(ctx, eg, reqs, found, keeper, stats) if err != nil { - return true, err + return true, gcb, err + } + if gcb != gcBehavior_Continue { + return true, gcb, nil } if !remaining { - return false, nil + return false, gcb, nil } } - return true, nil + return true, gcBehavior_Continue, nil } func (crg chunkReaderGroup) count() (count uint32, err error) { diff --git a/go/store/nbs/nbs_metrics_wrapper.go b/go/store/nbs/nbs_metrics_wrapper.go index 1ca852da04d..36b262075b7 100644 --- a/go/store/nbs/nbs_metrics_wrapper.go +++ b/go/store/nbs/nbs_metrics_wrapper.go @@ -71,12 +71,12 @@ func (nbsMW *NBSMetricWrapper) SupportedOperations() chunks.TableFileStoreOps { return nbsMW.nbs.SupportedOperations() } -func (nbsMW *NBSMetricWrapper) BeginGC(keeper func(hash.Hash) bool) error { - return nbsMW.nbs.BeginGC(keeper) +func (nbsMW *NBSMetricWrapper) BeginGC(keeper func(hash.Hash) bool, mode chunks.GCMode) error { + return nbsMW.nbs.BeginGC(keeper, mode) } -func (nbsMW *NBSMetricWrapper) EndGC() { - nbsMW.nbs.EndGC() +func (nbsMW *NBSMetricWrapper) EndGC(mode chunks.GCMode) { + nbsMW.nbs.EndGC(mode) } func (nbsMW *NBSMetricWrapper) MarkAndSweepChunks(ctx context.Context, getAddrs chunks.GetAddrsCurry, filter chunks.HasManyFunc, dest chunks.ChunkStore, mode chunks.GCMode) (chunks.MarkAndSweeper, error) { diff --git a/go/store/nbs/no_conjoin_bs_persister.go b/go/store/nbs/no_conjoin_bs_persister.go index 053c9be710e..98ed3a06a5c 100644 --- a/go/store/nbs/no_conjoin_bs_persister.go +++ b/go/store/nbs/no_conjoin_bs_persister.go @@ -21,7 +21,6 @@ import ( "io" "time" - "github.com/fatih/color" "golang.org/x/sync/errgroup" "github.com/dolthub/dolt/go/store/blobstore" @@ -40,27 +39,32 @@ var _ tableFilePersister = &noConjoinBlobstorePersister{} // Persist makes the contents of mt durable. Chunks already present in // |haver| may be dropped in the process. 
-func (bsp *noConjoinBlobstorePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, stats *Stats) (chunkSource, error) { - address, data, chunkCount, err := mt.write(haver, stats) +func (bsp *noConjoinBlobstorePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, keeper keeperF, stats *Stats) (chunkSource, gcBehavior, error) { + address, data, chunkCount, gcb, err := mt.write(haver, keeper, stats) if err != nil { - return emptyChunkSource{}, err + return emptyChunkSource{}, gcBehavior_Continue, err + } else if gcb != gcBehavior_Continue { + return emptyChunkSource{}, gcb, nil } else if chunkCount == 0 { - return emptyChunkSource{}, nil + return emptyChunkSource{}, gcBehavior_Continue, nil } name := address.String() eg, ectx := errgroup.WithContext(ctx) - eg.Go(func() (err error) { - fmt.Fprintf(color.Output, "Persist: bs.Put: name: %s\n", name) - _, err = bsp.bs.Put(ectx, name, int64(len(data)), bytes.NewBuffer(data)) - return + eg.Go(func() error { + _, err := bsp.bs.Put(ectx, name, int64(len(data)), bytes.NewBuffer(data)) + return err }) if err = eg.Wait(); err != nil { - return nil, err + return nil, gcBehavior_Continue, err } rdr := &bsTableReaderAt{name, bsp.bs} - return newReaderFromIndexData(ctx, bsp.q, data, address, rdr, bsp.blockSize) + src, err := newReaderFromIndexData(ctx, bsp.q, data, address, rdr, bsp.blockSize) + if err != nil { + return nil, gcBehavior_Continue, err + } + return src, gcBehavior_Continue, nil } // ConjoinAll implements tablePersister. diff --git a/go/store/nbs/root_tracker_test.go b/go/store/nbs/root_tracker_test.go index f428f0817eb..37d82483743 100644 --- a/go/store/nbs/root_tracker_test.go +++ b/go/store/nbs/root_tracker_test.go @@ -399,7 +399,7 @@ func interloperWrite(fm *fakeManifest, p tablePersister, rootChunk []byte, chunk persisted = append(chunks, rootChunk) var src chunkSource - src, err = p.Persist(context.Background(), createMemTable(persisted), nil, &Stats{}) + src, _, err = p.Persist(context.Background(), createMemTable(persisted), nil, nil, &Stats{}) if err != nil { return hash.Hash{}, nil, err } @@ -505,16 +505,18 @@ type fakeTablePersister struct { var _ tablePersister = fakeTablePersister{} -func (ftp fakeTablePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, stats *Stats) (chunkSource, error) { +func (ftp fakeTablePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, keeper keeperF, stats *Stats) (chunkSource, gcBehavior, error) { if mustUint32(mt.count()) == 0 { - return emptyChunkSource{}, nil + return emptyChunkSource{}, gcBehavior_Continue, nil } - name, data, chunkCount, err := mt.write(haver, stats) + name, data, chunkCount, gcb, err := mt.write(haver, keeper, stats) if err != nil { - return emptyChunkSource{}, err + return emptyChunkSource{}, gcBehavior_Continue, err + } else if gcb != gcBehavior_Continue { + return emptyChunkSource{}, gcb, nil } else if chunkCount == 0 { - return emptyChunkSource{}, nil + return emptyChunkSource{}, gcBehavior_Continue, nil } ftp.mu.Lock() @@ -523,14 +525,14 @@ func (ftp fakeTablePersister) Persist(ctx context.Context, mt *memTable, haver c ti, err := parseTableIndexByCopy(ctx, data, ftp.q) if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } cs, err := newTableReader(ti, tableReaderAtFromBytes(data), fileBlockSize) if err != nil { - return emptyChunkSource{}, err + return emptyChunkSource{}, gcBehavior_Continue, err } - return chunkSourceAdapter{cs, name}, nil + return chunkSourceAdapter{cs, 
name}, gcBehavior_Continue, nil } func (ftp fakeTablePersister) ConjoinAll(ctx context.Context, sources chunkSources, stats *Stats) (chunkSource, cleanupFunc, error) { @@ -661,7 +663,7 @@ func extractAllChunks(ctx context.Context, src chunkSource, cb func(rec extractR return err } - data, err := src.get(ctx, h, nil) + data, _, err := src.get(ctx, h, nil, nil) if err != nil { return err } diff --git a/go/store/nbs/store.go b/go/store/nbs/store.go index 7f08eb62c99..a9c4911e360 100644 --- a/go/store/nbs/store.go +++ b/go/store/nbs/store.go @@ -23,6 +23,7 @@ package nbs import ( "context" + "errors" "fmt" "io" "os" @@ -39,7 +40,6 @@ import ( lru "github.com/hashicorp/golang-lru/v2" "github.com/oracle/oci-go-sdk/v65/common" "github.com/oracle/oci-go-sdk/v65/objectstorage" - "github.com/pkg/errors" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -89,6 +89,17 @@ type NBSCompressedChunkStore interface { GetManyCompressed(context.Context, hash.HashSet, func(context.Context, CompressedChunk)) error } +type gcDependencyMode int + +const ( + gcDependencyMode_TakeDependency gcDependencyMode = iota + gcDependencyMode_NoDependency +) + +type CompressedChunkStoreForGC interface { + getManyCompressed(context.Context, hash.HashSet, func(context.Context, CompressedChunk), gcDependencyMode) error +} + type NomsBlockStore struct { mm manifestManager p tablePersister @@ -99,8 +110,14 @@ type NomsBlockStore struct { tables tableSet upstream manifestContents - cond *sync.Cond + cond *sync.Cond + // |true| after BeginGC is called, and false once the corresponding EndGC call returns. gcInProgress bool + // When unlocked read operations are occurring against the + // block store, and they started when |gcInProgress == true|, + // this variable is incremented. EndGC will not return until + // no outstanding reads are in progress.
+ gcOutstandingReads int // keeperFunc is set when |gcInProgress| and appends to the GC sweep queue // or blocks on GC finalize keeperFunc func(hash.Hash) bool @@ -152,14 +169,14 @@ func (nbs *NomsBlockStore) GetChunkLocationsWithPaths(ctx context.Context, hashe } func (nbs *NomsBlockStore) GetChunkLocations(ctx context.Context, hashes hash.HashSet) (map[hash.Hash]map[hash.Hash]Range, error) { - gr := toGetRecords(hashes) - ranges := make(map[hash.Hash]map[hash.Hash]Range) - - fn := func(css chunkSourceSet) error { + fn := func(css chunkSourceSet, gr []getRecord, ranges map[hash.Hash]map[hash.Hash]Range, keeper keeperF) (gcBehavior, error) { for _, cs := range css { - rng, err := cs.getRecordRanges(ctx, gr) + rng, gcb, err := cs.getRecordRanges(ctx, gr, keeper) if err != nil { - return err + return gcBehavior_Continue, err + } + if gcb != gcBehavior_Continue { + return gcb, nil } h := hash.Hash(cs.hash()) @@ -171,22 +188,60 @@ func (nbs *NomsBlockStore) GetChunkLocations(ctx context.Context, hashes hash.Ha ranges[h] = rng } } - return nil + return gcBehavior_Continue, nil } - tables := func() tableSet { - nbs.mu.RLock() - defer nbs.mu.RUnlock() - return nbs.tables - }() + for { + nbs.mu.Lock() + tables, keeper, endRead := nbs.tables, nbs.keeperFunc, nbs.beginRead() + nbs.mu.Unlock() - if err := fn(tables.upstream); err != nil { - return nil, err + gr := toGetRecords(hashes) + ranges := make(map[hash.Hash]map[hash.Hash]Range) + + gcb, err := fn(tables.upstream, gr, ranges, keeper) + if needsContinue, err := nbs.handleUnlockedRead(ctx, gcb, endRead, err); err != nil { + return nil, err + } else if needsContinue { + continue + } + + gcb, err = fn(tables.novel, gr, ranges, keeper) + if needsContinue, err := nbs.handleUnlockedRead(ctx, gcb, endRead, err); err != nil { + return nil, err + } else if needsContinue { + continue + } + + return ranges, nil } - if err := fn(tables.novel); err != nil { - return nil, err +} + +func (nbs *NomsBlockStore) handleUnlockedRead(ctx context.Context, gcb gcBehavior, endRead func(), err error) (bool, error) { + if err != nil { + if endRead != nil { + nbs.mu.Lock() + endRead() + nbs.mu.Unlock() + } + return false, err + } + if gcb == gcBehavior_Block { + nbs.mu.Lock() + if endRead != nil { + endRead() + } + err := nbs.waitForGC(ctx) + nbs.mu.Unlock() + return true, err + } else { + if endRead != nil { + nbs.mu.Lock() + endRead() + nbs.mu.Unlock() + } + return false, nil } - return ranges, nil } func (nbs *NomsBlockStore) conjoinIfRequired(ctx context.Context) (bool, error) { @@ -218,11 +273,6 @@ func (nbs *NomsBlockStore) conjoinIfRequired(ctx context.Context) (bool, error) func (nbs *NomsBlockStore) UpdateManifest(ctx context.Context, updates map[hash.Hash]uint32) (mi ManifestInfo, err error) { nbs.mu.Lock() defer nbs.mu.Unlock() - err = nbs.waitForGC(ctx) - if err != nil { - return - } - err = nbs.checkAllManifestUpdatesExist(ctx, updates) if err != nil { return @@ -306,11 +356,6 @@ func (nbs *NomsBlockStore) UpdateManifest(ctx context.Context, updates map[hash. 
func (nbs *NomsBlockStore) UpdateManifestWithAppendix(ctx context.Context, updates map[hash.Hash]uint32, option ManifestAppendixOption) (mi ManifestInfo, err error) { nbs.mu.Lock() defer nbs.mu.Unlock() - err = nbs.waitForGC(ctx) - if err != nil { - return - } - err = nbs.checkAllManifestUpdatesExist(ctx, updates) if err != nil { return @@ -462,11 +507,6 @@ func fromManifestAppendixOptionNewContents(upstream manifestContents, appendixSp func OverwriteStoreManifest(ctx context.Context, store *NomsBlockStore, root hash.Hash, tableFiles map[hash.Hash]uint32, appendixTableFiles map[hash.Hash]uint32) (err error) { store.mu.Lock() defer store.mu.Unlock() - err = store.waitForGC(ctx) - if err != nil { - return - } - contents := manifestContents{ root: root, nbfVers: store.upstream.nbfVers, @@ -703,7 +743,9 @@ func (nbs *NomsBlockStore) WithoutConjoiner() *NomsBlockStore { } } -// Wait for GC to complete to continue with writes +// Wait for GC to complete to continue with ongoing operations. +// Called with nbs.mu held. When this function returns with a nil +// error, gcInProgress will be false. func (nbs *NomsBlockStore) waitForGC(ctx context.Context) error { stop := make(chan struct{}) defer close(stop) @@ -721,7 +763,7 @@ func (nbs *NomsBlockStore) waitForGC(ctx context.Context) error { } func (nbs *NomsBlockStore) Put(ctx context.Context, c chunks.Chunk, getAddrs chunks.GetAddrsCurry) error { - return nbs.putChunk(ctx, c, getAddrs, nbs.hasMany) + return nbs.putChunk(ctx, c, getAddrs, nbs.refCheck) } func (nbs *NomsBlockStore) putChunk(ctx context.Context, c chunks.Chunk, getAddrs chunks.GetAddrsCurry, checker refCheck) error { @@ -787,11 +829,18 @@ func (nbs *NomsBlockStore) addChunk(ctx context.Context, ch chunks.Chunk, getAdd addChunkRes = nbs.mt.addChunk(ch.Hash(), ch.Data()) if addChunkRes == chunkNotAdded { - ts, err := nbs.tables.append(ctx, nbs.mt, checker, nbs.hasCache, nbs.stats) + ts, gcb, err := nbs.tables.append(ctx, nbs.mt, checker, nbs.keeperFunc, nbs.hasCache, nbs.stats) if err != nil { nbs.handlePossibleDanglingRefError(err) return false, err } + if gcb == gcBehavior_Block { + retry = true + if err := nbs.waitForGC(ctx); err != nil { + return false, err + } + continue + } nbs.addPendingRefsToHasCache() nbs.tables = ts nbs.mt = newMemTable(nbs.mtSize) @@ -845,100 +894,134 @@ func (nbs *NomsBlockStore) Get(ctx context.Context, h hash.Hash) (chunks.Chunk, nbs.stats.ChunksPerGet.Sample(1) }() - data, tables, err := func() ([]byte, chunkReader, error) { - var data []byte - nbs.mu.RLock() - defer nbs.mu.RUnlock() + for { + nbs.mu.Lock() if nbs.mt != nil { - var err error - data, err = nbs.mt.get(ctx, h, nbs.stats) - + data, gcb, err := nbs.mt.get(ctx, h, nbs.keeperFunc, nbs.stats) if err != nil { - return nil, nil, err + nbs.mu.Unlock() + return chunks.EmptyChunk, err + } + if gcb == gcBehavior_Block { + err = nbs.waitForGC(ctx) + nbs.mu.Unlock() + if err != nil { + return chunks.EmptyChunk, err + } + continue + } + if data != nil { + nbs.mu.Unlock() + return chunks.NewChunkWithHash(h, data), nil } } - return data, nbs.tables, nil - }() - - if err != nil { - return chunks.EmptyChunk, err - } + tables, keeper, endRead := nbs.tables, nbs.keeperFunc, nbs.beginRead() + nbs.mu.Unlock() - if data != nil { - return chunks.NewChunkWithHash(h, data), nil - } - - data, err = tables.get(ctx, h, nbs.stats) - - if err != nil { - return chunks.EmptyChunk, err - } + data, gcb, err := tables.get(ctx, h, keeper, nbs.stats) + needContinue, err := nbs.handleUnlockedRead(ctx, gcb, endRead, err) + if err 
!= nil { + return chunks.EmptyChunk, err + } + if needContinue { + continue + } - if data != nil { - return chunks.NewChunkWithHash(h, data), nil + if data != nil { + return chunks.NewChunkWithHash(h, data), nil + } + return chunks.EmptyChunk, nil } - - return chunks.EmptyChunk, nil } func (nbs *NomsBlockStore) GetMany(ctx context.Context, hashes hash.HashSet, found func(context.Context, *chunks.Chunk)) error { ctx, span := tracer.Start(ctx, "nbs.GetMany", trace.WithAttributes(attribute.Int("num_hashes", len(hashes)))) - span.End() - return nbs.getManyWithFunc(ctx, hashes, func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, stats *Stats) (bool, error) { - return cr.getMany(ctx, eg, reqs, found, nbs.stats) - }) + defer span.End() + return nbs.getManyWithFunc(ctx, hashes, gcDependencyMode_TakeDependency, + func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, keeper keeperF, stats *Stats) (bool, gcBehavior, error) { + return cr.getMany(ctx, eg, reqs, found, keeper, nbs.stats) + }, + ) } func (nbs *NomsBlockStore) GetManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, CompressedChunk)) error { + return nbs.getManyCompressed(ctx, hashes, found, gcDependencyMode_TakeDependency) +} + +func (nbs *NomsBlockStore) getManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, CompressedChunk), gcDepMode gcDependencyMode) error { ctx, span := tracer.Start(ctx, "nbs.GetManyCompressed", trace.WithAttributes(attribute.Int("num_hashes", len(hashes)))) defer span.End() - return nbs.getManyWithFunc(ctx, hashes, func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, stats *Stats) (bool, error) { - return cr.getManyCompressed(ctx, eg, reqs, found, nbs.stats) - }) + return nbs.getManyWithFunc(ctx, hashes, gcDepMode, + func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, keeper keeperF, stats *Stats) (bool, gcBehavior, error) { + return cr.getManyCompressed(ctx, eg, reqs, found, keeper, nbs.stats) + }, + ) } func (nbs *NomsBlockStore) getManyWithFunc( ctx context.Context, hashes hash.HashSet, - getManyFunc func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, stats *Stats) (bool, error), + gcDepMode gcDependencyMode, + getManyFunc func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, keeper keeperF, stats *Stats) (bool, gcBehavior, error), ) error { - t1 := time.Now() - reqs := toGetRecords(hashes) + if len(hashes) == 0 { + return nil + } + t1 := time.Now() defer func() { - if len(hashes) > 0 { - nbs.stats.GetLatency.SampleTimeSince(t1) - nbs.stats.ChunksPerGet.Sample(uint64(len(reqs))) - } + nbs.stats.GetLatency.SampleTimeSince(t1) + nbs.stats.ChunksPerGet.Sample(uint64(len(hashes))) }() - eg, ctx := errgroup.WithContext(ctx) const ioParallelism = 16 - eg.SetLimit(ioParallelism) + for { + reqs := toGetRecords(hashes) - tables, remaining, err := func() (tables chunkReader, remaining bool, err error) { - nbs.mu.RLock() - defer nbs.mu.RUnlock() - tables = nbs.tables - remaining = true + nbs.mu.Lock() + keeper := nbs.keeperFunc + if gcDepMode == gcDependencyMode_NoDependency { + keeper = nil + } if nbs.mt != nil { - remaining, err = getManyFunc(ctx, nbs.mt, eg, reqs, nbs.stats) + // nbs.mt does not use the errgroup parameter, which we pass at |nil| here. 
+ remaining, gcb, err := getManyFunc(ctx, nbs.mt, nil, reqs, keeper, nbs.stats) + if err != nil { + nbs.mu.Unlock() + return err + } + if gcb == gcBehavior_Block { + err = nbs.waitForGC(ctx) + nbs.mu.Unlock() + if err != nil { + return err + } + continue + } + if !remaining { + nbs.mu.Unlock() + return nil + } } - return - }() - if err != nil { - return err - } + tables, endRead := nbs.tables, nbs.beginRead() + nbs.mu.Unlock() - if remaining { - _, err = getManyFunc(ctx, tables, eg, reqs, nbs.stats) - } + gcb, err := func() (gcBehavior, error) { + eg, ctx := errgroup.WithContext(ctx) + eg.SetLimit(ioParallelism) + _, gcb, err := getManyFunc(ctx, tables, eg, reqs, keeper, nbs.stats) + return gcb, errors.Join(err, eg.Wait()) + }() + needContinue, err := nbs.handleUnlockedRead(ctx, gcb, endRead, err) + if err != nil { + return err + } + if needContinue { + continue + } - if err != nil { - eg.Wait() - return err + return nil } - return eg.Wait() } func toGetRecords(hashes hash.HashSet) []getRecord { @@ -992,73 +1075,138 @@ func (nbs *NomsBlockStore) Has(ctx context.Context, h hash.Hash) (bool, error) { nbs.stats.AddressesPerHas.Sample(1) }() - has, tables, err := func() (bool, chunkReader, error) { - nbs.mu.RLock() - defer nbs.mu.RUnlock() - + for { + nbs.mu.Lock() if nbs.mt != nil { - has, err := nbs.mt.has(h) - + has, gcb, err := nbs.mt.has(h, nbs.keeperFunc) if err != nil { - return false, nil, err + nbs.mu.Unlock() + return false, err + } + if gcb == gcBehavior_Block { + err = nbs.waitForGC(ctx) + nbs.mu.Unlock() + if err != nil { + return false, err + } + continue + } + if has { + nbs.mu.Unlock() + return true, nil } - - return has, nbs.tables, nil } + tables, keeper, endRead := nbs.tables, nbs.keeperFunc, nbs.beginRead() + nbs.mu.Unlock() - return false, nbs.tables, nil - }() - - if err != nil { - return false, err - } - - if !has { - has, err = tables.has(h) - + has, gcb, err := tables.has(h, keeper) + needsContinue, err := nbs.handleUnlockedRead(ctx, gcb, endRead, err) if err != nil { return false, err } - } + if needsContinue { + continue + } - return has, nil + return has, nil + } } func (nbs *NomsBlockStore) HasMany(ctx context.Context, hashes hash.HashSet) (hash.HashSet, error) { + return nbs.hasManyDep(ctx, hashes, gcDependencyMode_TakeDependency) +} + +func (nbs *NomsBlockStore) hasManyDep(ctx context.Context, hashes hash.HashSet, gcDepMode gcDependencyMode) (hash.HashSet, error) { if hashes.Size() == 0 { return nil, nil } t1 := time.Now() - defer nbs.stats.HasLatency.SampleTimeSince(t1) - nbs.stats.AddressesPerHas.SampleLen(hashes.Size()) + defer func() { + nbs.stats.HasLatency.SampleTimeSince(t1) + nbs.stats.AddressesPerHas.SampleLen(hashes.Size()) + }() - nbs.mu.RLock() - defer nbs.mu.RUnlock() - return nbs.hasMany(toHasRecords(hashes)) -} + for { + reqs := toHasRecords(hashes) -func (nbs *NomsBlockStore) hasManyInSources(srcs []hash.Hash, hashes hash.HashSet) (hash.HashSet, error) { - if hashes.Size() == 0 { - return nil, nil - } + nbs.mu.Lock() + if nbs.mt != nil { + keeper := nbs.keeperFunc + if gcDepMode == gcDependencyMode_NoDependency { + keeper = nil + } + remaining, gcb, err := nbs.mt.hasMany(reqs, keeper) + if err != nil { + nbs.mu.Unlock() + return nil, err + } + if gcb == gcBehavior_Block { + err = nbs.waitForGC(ctx) + nbs.mu.Unlock() + if err != nil { + return nil, err + } + continue + } + if !remaining { + nbs.mu.Unlock() + return hash.HashSet{}, nil + } + } + tables, keeper, endRead := nbs.tables, nbs.keeperFunc, nbs.beginRead() + if gcDepMode == 
gcDependencyMode_NoDependency { + keeper = nil + } + nbs.mu.Unlock() - t1 := time.Now() - defer nbs.stats.HasLatency.SampleTimeSince(t1) - nbs.stats.AddressesPerHas.SampleLen(hashes.Size()) + remaining, gcb, err := tables.hasMany(reqs, keeper) + needContinue, err := nbs.handleUnlockedRead(ctx, gcb, endRead, err) + if err != nil { + return nil, err + } + if needContinue { + continue + } - nbs.mu.RLock() - defer nbs.mu.RUnlock() + if !remaining { + return hash.HashSet{}, nil + } - records := toHasRecords(hashes) + absent := hash.HashSet{} + for _, r := range reqs { + if !r.has { + absent.Insert(*r.a) + } + } + return absent, nil + } +} + +// Operates a lot like |hasMany|, but without locking and without +// taking read dependencies on the checked references. Should only be +// used for the sanity checking on references for written chunks. +func (nbs *NomsBlockStore) refCheck(reqs []hasRecord) (hash.HashSet, error) { + if nbs.mt != nil { + remaining, _, err := nbs.mt.hasMany(reqs, nil) + if err != nil { + return nil, err + } + if !remaining { + return hash.HashSet{}, nil + } + } - _, err := nbs.tables.hasManyInSources(srcs, records) + remaining, _, err := nbs.tables.hasMany(reqs, nil) if err != nil { return nil, err } + if !remaining { + return hash.HashSet{}, nil + } absent := hash.HashSet{} - for _, r := range records { + for _, r := range reqs { if !r.has { absent.Insert(*r.a) } @@ -1066,36 +1214,32 @@ func (nbs *NomsBlockStore) hasManyInSources(srcs []hash.Hash, hashes hash.HashSe return absent, nil } -func (nbs *NomsBlockStore) hasMany(reqs []hasRecord) (hash.HashSet, error) { - tables, remaining, err := func() (tables chunkReader, remaining bool, err error) { - tables = nbs.tables +// Only used for a generational full GC, where the table files are +// added to the store and are then used to filter which chunks need to +// make it to the new generation. In this context, we do not need to +// worry about taking read dependencies on the requested chunks. Hence +// our handling of keeperFunc and gcBehavior below. 
+func (nbs *NomsBlockStore) hasManyInSources(srcs []hash.Hash, hashes hash.HashSet) (hash.HashSet, error) { + if hashes.Size() == 0 { + return nil, nil + } - remaining = true - if nbs.mt != nil { - remaining, err = nbs.mt.hasMany(reqs) + t1 := time.Now() + defer nbs.stats.HasLatency.SampleTimeSince(t1) + nbs.stats.AddressesPerHas.SampleLen(hashes.Size()) - if err != nil { - return nil, false, err - } - } + nbs.mu.RLock() + defer nbs.mu.RUnlock() - return tables, remaining, nil - }() + records := toHasRecords(hashes) + _, _, err := nbs.tables.hasManyInSources(srcs, records, nil) if err != nil { return nil, err } - if remaining { - _, err := tables.hasMany(reqs) - - if err != nil { - return nil, err - } - } - absent := hash.HashSet{} - for _, r := range reqs { + for _, r := range records { if !r.has { absent.Insert(*r.a) } @@ -1162,7 +1306,7 @@ func (nbs *NomsBlockStore) Root(ctx context.Context) (hash.Hash, error) { } func (nbs *NomsBlockStore) Commit(ctx context.Context, current, last hash.Hash) (success bool, err error) { - return nbs.commit(ctx, current, last, nbs.hasMany) + return nbs.commit(ctx, current, last, nbs.refCheck) } func (nbs *NomsBlockStore) commit(ctx context.Context, current, last hash.Hash, checker refCheck) (success bool, err error) { @@ -1251,22 +1395,30 @@ func (nbs *NomsBlockStore) updateManifest(ctx context.Context, current, last has return handleOptimisticLockFailure(cached) } - if nbs.mt != nil { - cnt, err := nbs.mt.count() - - if err != nil { - return err - } - - if cnt > 0 { - ts, err := nbs.tables.append(ctx, nbs.mt, checker, nbs.hasCache, nbs.stats) + for { + if nbs.mt != nil { + cnt, err := nbs.mt.count() if err != nil { - nbs.handlePossibleDanglingRefError(err) return err } - nbs.addPendingRefsToHasCache() - nbs.tables, nbs.mt = ts, nil + if cnt > 0 { + ts, gcb, err := nbs.tables.append(ctx, nbs.mt, checker, nbs.keeperFunc, nbs.hasCache, nbs.stats) + if err != nil { + nbs.handlePossibleDanglingRefError(err) + return err + } + if gcb == gcBehavior_Block { + err = nbs.waitForGC(ctx) + if err != nil { + return err + } + continue + } + nbs.addPendingRefsToHasCache() + nbs.tables, nbs.mt = ts, nil + } } + break } didConjoin, err := nbs.conjoinIfRequired(ctx) @@ -1555,12 +1707,11 @@ func (nbs *NomsBlockStore) AddTableFilesToManifest(ctx context.Context, fileIdTo // PruneTableFiles deletes old table files that are no longer referenced in the manifest. 
func (nbs *NomsBlockStore) PruneTableFiles(ctx context.Context) (err error) { - return nbs.pruneTableFiles(ctx, nbs.hasMany) + return nbs.pruneTableFiles(ctx) } -func (nbs *NomsBlockStore) pruneTableFiles(ctx context.Context, checker refCheck) (err error) { +func (nbs *NomsBlockStore) pruneTableFiles(ctx context.Context) (err error) { mtime := time.Now() - return nbs.p.PruneTableFiles(ctx, func() []hash.Hash { nbs.mu.Lock() defer nbs.mu.Unlock() @@ -1575,7 +1726,7 @@ func (nbs *NomsBlockStore) pruneTableFiles(ctx context.Context, checker refCheck }, mtime) } -func (nbs *NomsBlockStore) BeginGC(keeper func(hash.Hash) bool) error { +func (nbs *NomsBlockStore) BeginGC(keeper func(hash.Hash) bool, _ chunks.GCMode) error { nbs.cond.L.Lock() defer nbs.cond.L.Unlock() if nbs.gcInProgress { @@ -1587,22 +1738,47 @@ return nil } -func (nbs *NomsBlockStore) EndGC() { +func (nbs *NomsBlockStore) EndGC(_ chunks.GCMode) { nbs.cond.L.Lock() defer nbs.cond.L.Unlock() if !nbs.gcInProgress { panic("EndGC called when gc was not in progress") } + for nbs.gcOutstandingReads > 0 { + nbs.cond.Wait() + } nbs.gcInProgress = false nbs.keeperFunc = nil nbs.cond.Broadcast() } +// beginRead() is called with |nbs.mu| held. It signals an ongoing +// read operation which will be operating against the existing table +// files without |nbs.mu| held. The read should be bracketed with a call +// to the returned |endRead|, which must be called with |nbs.mu| held +// if it is non-|nil|, and should not be called otherwise. +// +// If there is an ongoing GC operation when this call is made, it is +// guaranteed not to complete until the corresponding |endRead| call. +func (nbs *NomsBlockStore) beginRead() (endRead func()) { + if nbs.gcInProgress { + nbs.gcOutstandingReads += 1 + return func() { + nbs.gcOutstandingReads -= 1 + if nbs.gcOutstandingReads < 0 { + panic("impossible") + } + nbs.cond.Broadcast() + } + } + return nil +} + func (nbs *NomsBlockStore) MarkAndSweepChunks(ctx context.Context, getAddrs chunks.GetAddrsCurry, filter chunks.HasManyFunc, dest chunks.ChunkStore, mode chunks.GCMode) (chunks.MarkAndSweeper, error) { return markAndSweepChunks(ctx, nbs, nbs, dest, getAddrs, filter, mode) } -func markAndSweepChunks(ctx context.Context, nbs *NomsBlockStore, src NBSCompressedChunkStore, dest chunks.ChunkStore, getAddrs chunks.GetAddrsCurry, filter chunks.HasManyFunc, mode chunks.GCMode) (chunks.MarkAndSweeper, error) { +func markAndSweepChunks(ctx context.Context, nbs *NomsBlockStore, src CompressedChunkStoreForGC, dest chunks.ChunkStore, getAddrs chunks.GetAddrsCurry, filter chunks.HasManyFunc, mode chunks.GCMode) (chunks.MarkAndSweeper, error) { ops := nbs.SupportedOperations() if !ops.CanGC || !ops.CanPrune { return nil, chunks.ErrUnsupportedOperation @@ -1670,7 +1846,7 @@ func markAndSweepChunks(ctx context.Context, nbs *NomsBlockStore, src NBSCompres } type markAndSweeper struct { - src NBSCompressedChunkStore + src CompressedChunkStoreForGC dest *NomsBlockStore getAddrs chunks.GetAddrsCurry filter chunks.HasManyFunc @@ -1716,7 +1892,7 @@ func (i *markAndSweeper) SaveHashes(ctx context.Context, hashes []hash.Hash) err found := 0 var addErr error - err = i.src.GetManyCompressed(ctx, toVisit, func(ctx context.Context, cc CompressedChunk) { + err = i.src.getManyCompressed(ctx, toVisit, func(ctx context.Context, cc CompressedChunk) { mu.Lock() defer mu.Unlock() if addErr != nil { @@ -1740,7 +1916,7 @@ func (i *markAndSweeper) SaveHashes(ctx
context.Context, hashes []hash.Hash) err return } addErr = i.getAddrs(c)(ctx, nextToVisit, func(h hash.Hash) bool { return false }) - }) + }, gcDependencyMode_NoDependency) if err != nil { return err } @@ -1905,7 +2081,7 @@ func (nbs *NomsBlockStore) swapTables(ctx context.Context, specs []tableSpec, mo // SetRootChunk changes the root chunk hash from the previous value to the new root. func (nbs *NomsBlockStore) SetRootChunk(ctx context.Context, root, previous hash.Hash) error { - return nbs.setRootChunk(ctx, root, previous, nbs.hasMany) + return nbs.setRootChunk(ctx, root, previous, nbs.refCheck) } func (nbs *NomsBlockStore) setRootChunk(ctx context.Context, root, previous hash.Hash, checker refCheck) error { @@ -1932,7 +2108,7 @@ func (nbs *NomsBlockStore) setRootChunk(ctx context.Context, root, previous hash } // CalcReads computes the number of IO operations necessary to fetch |hashes|. -func CalcReads(nbs *NomsBlockStore, hashes hash.HashSet, blockSize uint64) (reads int, split bool, err error) { +func CalcReads(nbs *NomsBlockStore, hashes hash.HashSet, blockSize uint64, keeper keeperF) (int, bool, gcBehavior, error) { reqs := toGetRecords(hashes) tables := func() (tables tableSet) { nbs.mu.RLock() @@ -1942,15 +2118,17 @@ func CalcReads(nbs *NomsBlockStore, hashes hash.HashSet, blockSize uint64) (read return }() - reads, split, remaining, err := tableSetCalcReads(tables, reqs, blockSize) - + reads, split, remaining, gcb, err := tableSetCalcReads(tables, reqs, blockSize, keeper) if err != nil { - return 0, false, err + return 0, false, gcb, err + } + if gcb != gcBehavior_Continue { + return 0, false, gcb, nil } if remaining { - return 0, false, errors.New("failed to find all chunks") + return 0, false, gcBehavior_Continue, errors.New("failed to find all chunks") } - return + return reads, split, gcb, err } diff --git a/go/store/nbs/store_test.go b/go/store/nbs/store_test.go index 90f204bc996..ffcb5459546 100644 --- a/go/store/nbs/store_test.go +++ b/go/store/nbs/store_test.go @@ -44,6 +44,9 @@ func makeTestLocalStore(t *testing.T, maxTableFiles int) (st *NomsBlockStore, no nomsDir = filepath.Join(tempfiles.MovableTempFileProvider.GetTempDir(), "noms_"+uuid.New().String()[:8]) err := os.MkdirAll(nomsDir, os.ModePerm) require.NoError(t, err) + t.Cleanup(func() { + os.RemoveAll(nomsDir) + }) // create a v5 manifest fm, err := getFileManifest(ctx, nomsDir, asyncFlush) @@ -209,7 +212,7 @@ func TestNBSPruneTableFiles(t *testing.T) { addrs.Insert(c.Hash()) return nil } - }, st.hasMany) + }, st.refCheck) require.NoError(t, err) require.True(t, ok) ok, err = st.Commit(ctx, st.upstream.root, st.upstream.root) @@ -334,7 +337,7 @@ func TestNBSCopyGC(t *testing.T) { require.NoError(t, err) require.True(t, ok) - require.NoError(t, st.BeginGC(nil)) + require.NoError(t, st.BeginGC(nil, chunks.GCMode_Full)) noopFilter := func(ctx context.Context, hashes hash.HashSet) (hash.HashSet, error) { return hashes, nil } @@ -349,7 +352,7 @@ func TestNBSCopyGC(t *testing.T) { require.NoError(t, err) require.NoError(t, sweeper.Close(ctx)) require.NoError(t, finalizer.SwapChunksInStore(ctx)) - st.EndGC() + st.EndGC(chunks.GCMode_Full) for h, c := range keepers { out, err := st.Get(ctx, h) @@ -378,7 +381,7 @@ func persistTableFileSources(t *testing.T, p tablePersister, numTableFiles int) require.True(t, ok) tableFileMap[fileIDHash] = uint32(i + 1) mapIds[i] = fileIDHash - cs, err := p.Persist(context.Background(), createMemTable(chunkData), nil, &Stats{}) + cs, _, err := p.Persist(context.Background(), 
createMemTable(chunkData), nil, nil, &Stats{}) require.NoError(t, err) require.NoError(t, cs.close()) diff --git a/go/store/nbs/table.go b/go/store/nbs/table.go index b487422d079..c4763989734 100644 --- a/go/store/nbs/table.go +++ b/go/store/nbs/table.go @@ -187,24 +187,41 @@ type extractRecord struct { err error } +// Returned by read methods that take a |keeperFunc|, this lets a +// caller know whether the operation was successful or if it needs to +// be retried. It may need to be retried if a GC is in progress but +// the dependencies indicated by the operation cannot be added to the +// GC process. In that case, the caller needs to wait until the GC is +// over and run the entire operation again. +type gcBehavior bool + +const ( + // Operation was successful, go forward with the result. + gcBehavior_Continue gcBehavior = false + // Operation needs to block until the GC is over and then retry. + gcBehavior_Block = true +) + +type keeperF func(hash.Hash) bool + type chunkReader interface { // has returns true if a chunk with addr |h| is present. - has(h hash.Hash) (bool, error) + has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) // hasMany sets hasRecord.has to true for each present hasRecord query, it returns // true if any hasRecord query was not found in this chunkReader. - hasMany(addrs []hasRecord) (bool, error) + hasMany(addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) // get returns the chunk data for a chunk with addr |h| if present, and nil otherwise. - get(ctx context.Context, h hash.Hash, stats *Stats) ([]byte, error) + get(ctx context.Context, h hash.Hash, keeper keeperF, stats *Stats) ([]byte, gcBehavior, error) // getMany sets getRecord.found to true, and calls |found| for each present getRecord query. // It returns true if any getRecord query was not found in this chunkReader. - getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (bool, error) + getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) // getManyCompressed sets getRecord.found to true, and calls |found| for each present getRecord query. // It returns true if any getRecord query was not found in this chunkReader. - getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) (bool, error) + getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) // count returns the chunk count for this chunkReader. count() (uint32, error) @@ -226,7 +243,7 @@ type chunkSource interface { reader(context.Context) (io.ReadCloser, uint64, error) // getRecordRanges sets getRecord.found to true, and returns a Range for each present getRecord query. - getRecordRanges(ctx context.Context, requests []getRecord) (map[hash.Hash]Range, error) + getRecordRanges(ctx context.Context, requests []getRecord, keeper keeperF) (map[hash.Hash]Range, gcBehavior, error) // index returns the tableIndex of this chunkSource. 
index() (tableIndex, error) diff --git a/go/store/nbs/table_persister.go b/go/store/nbs/table_persister.go index 5c230daa091..6d283fb3ec5 100644 --- a/go/store/nbs/table_persister.go +++ b/go/store/nbs/table_persister.go @@ -47,7 +47,7 @@ type cleanupFunc func() type tablePersister interface { // Persist makes the contents of mt durable. Chunks already present in // |haver| may be dropped in the process. - Persist(ctx context.Context, mt *memTable, haver chunkReader, stats *Stats) (chunkSource, error) + Persist(ctx context.Context, mt *memTable, haver chunkReader, keeper keeperF, stats *Stats) (chunkSource, gcBehavior, error) // ConjoinAll conjoins all chunks in |sources| into a single, new // chunkSource. It returns a |cleanupFunc| which can be called to diff --git a/go/store/nbs/table_reader.go b/go/store/nbs/table_reader.go index 3ff059fb480..c55d48c28a4 100644 --- a/go/store/nbs/table_reader.go +++ b/go/store/nbs/table_reader.go @@ -178,7 +178,7 @@ func newTableReader(index tableIndex, r tableReaderAt, blockSize uint64) (tableR } // Scan across (logically) two ordered slices of address prefixes. -func (tr tableReader) hasMany(addrs []hasRecord) (bool, error) { +func (tr tableReader) hasMany(addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) { filterIdx := uint32(0) filterLen := uint32(tr.idx.chunkCount()) @@ -206,7 +206,7 @@ func (tr tableReader) hasMany(addrs []hasRecord) (bool, error) { } if filterIdx >= filterLen { - return true, nil + return true, gcBehavior_Continue, nil } if addr.prefix != tr.prefixes[filterIdx] { @@ -218,9 +218,12 @@ func (tr tableReader) hasMany(addrs []hasRecord) (bool, error) { for j := filterIdx; j < filterLen && addr.prefix == tr.prefixes[j]; j++ { m, err := tr.idx.entrySuffixMatches(j, addr.a) if err != nil { - return false, err + return false, gcBehavior_Continue, err } if m { + if keeper != nil && keeper(*addr.a) { + return true, gcBehavior_Block, nil + } addrs[i].has = true break } @@ -231,7 +234,7 @@ func (tr tableReader) hasMany(addrs []hasRecord) (bool, error) { } } - return remaining, nil + return remaining, gcBehavior_Continue, nil } func (tr tableReader) count() (uint32, error) { @@ -247,20 +250,27 @@ func (tr tableReader) index() (tableIndex, error) { } // returns true iff |h| can be found in this table. -func (tr tableReader) has(h hash.Hash) (bool, error) { +func (tr tableReader) has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) { _, ok, err := tr.idx.lookup(&h) - return ok, err + if ok && keeper != nil && keeper(h) { + return false, gcBehavior_Block, nil + } + return ok, gcBehavior_Continue, err } // returns the storage associated with |h|, iff present. Returns nil if absent. On success, // the returned byte slice directly references the underlying storage. 
-func (tr tableReader) get(ctx context.Context, h hash.Hash, stats *Stats) ([]byte, error) { +func (tr tableReader) get(ctx context.Context, h hash.Hash, keeper keeperF, stats *Stats) ([]byte, gcBehavior, error) { e, found, err := tr.idx.lookup(&h) if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } if !found { - return nil, nil + return nil, gcBehavior_Continue, nil + } + + if keeper != nil && keeper(h) { + return nil, gcBehavior_Block, nil } offset := e.Offset() @@ -270,30 +280,30 @@ func (tr tableReader) get(ctx context.Context, h hash.Hash, stats *Stats) ([]byt n, err := tr.r.ReadAtWithStats(ctx, buff, int64(offset), stats) if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } if n != int(length) { - return nil, errors.New("failed to read all data") + return nil, gcBehavior_Continue, errors.New("failed to read all data") } cmp, err := NewCompressedChunk(h, buff) if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } if len(cmp.CompressedData) == 0 { - return nil, errors.New("failed to get data") + return nil, gcBehavior_Continue, errors.New("failed to get data") } chnk, err := cmp.ToChunk() if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } - return chnk.Data(), nil + return chnk.Data(), gcBehavior_Continue, nil } type offsetRec struct { @@ -380,26 +390,33 @@ func (tr tableReader) getMany( eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), - stats *Stats) (bool, error) { + keeper keeperF, + stats *Stats) (bool, gcBehavior, error) { // Pass #1: Iterate over |reqs| and |tr.prefixes| (both sorted by address) and build the set // of table locations which must be read in order to satisfy the getMany operation. - offsetRecords, remaining, err := tr.findOffsets(reqs) + offsetRecords, remaining, gcb, err := tr.findOffsets(reqs, keeper) if err != nil { - return false, err + return false, gcBehavior_Continue, err + } + if gcb != gcBehavior_Continue { + return remaining, gcb, nil } err = tr.getManyAtOffsets(ctx, eg, offsetRecords, found, stats) - return remaining, err + return remaining, gcBehavior_Continue, err } -func (tr tableReader) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) (bool, error) { +func (tr tableReader) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { // Pass #1: Iterate over |reqs| and |tr.prefixes| (both sorted by address) and build the set // of table locations which must be read in order to satisfy the getMany operation. - offsetRecords, remaining, err := tr.findOffsets(reqs) + offsetRecords, remaining, gcb, err := tr.findOffsets(reqs, keeper) if err != nil { - return false, err + return false, gcb, err + } + if gcb != gcBehavior_Continue { + return remaining, gcb, nil } err = tr.getManyCompressedAtOffsets(ctx, eg, offsetRecords, found, stats) - return remaining, err + return remaining, gcBehavior_Continue, err } func (tr tableReader) getManyCompressedAtOffsets(ctx context.Context, eg *errgroup.Group, offsetRecords offsetRecSlice, found func(context.Context, CompressedChunk), stats *Stats) error { @@ -498,7 +515,7 @@ func (tr tableReader) getManyAtOffsetsWithReadFunc( // chunks remaining will be set to false upon return. If some are not here, // then remaining will be true. The result offsetRecSlice is sorted in offset // order. 
-func (tr tableReader) findOffsets(reqs []getRecord) (ors offsetRecSlice, remaining bool, err error) { +func (tr tableReader) findOffsets(reqs []getRecord, keeper keeperF) (ors offsetRecSlice, remaining bool, gcb gcBehavior, err error) { filterIdx := uint32(0) filterLen := uint32(len(tr.prefixes)) ors = make(offsetRecSlice, 0, len(reqs)) @@ -541,13 +558,16 @@ func (tr tableReader) findOffsets(reqs []getRecord) (ors offsetRecSlice, remaini for j := filterIdx; j < filterLen && req.prefix == tr.prefixes[j]; j++ { m, err := tr.idx.entrySuffixMatches(j, req.a) if err != nil { - return nil, false, err + return nil, false, gcBehavior_Continue, err } if m { + if keeper != nil && keeper(*req.a) { + return nil, false, gcBehavior_Block, nil + } reqs[i].found = true entry, err := tr.idx.indexEntry(j, nil) if err != nil { - return nil, false, err + return nil, false, gcBehavior_Continue, err } ors = append(ors, offsetRec{req.a, entry.Offset(), entry.Length()}) break @@ -560,7 +580,7 @@ func (tr tableReader) findOffsets(reqs []getRecord) (ors offsetRecSlice, remaini } sort.Sort(ors) - return ors, remaining, nil + return ors, remaining, gcBehavior_Continue, nil } func canReadAhead(fRec offsetRec, curStart, curEnd, blockSize uint64) (newEnd uint64, canRead bool) { @@ -584,12 +604,15 @@ func canReadAhead(fRec offsetRec, curStart, curEnd, blockSize uint64) (newEnd ui return fRec.offset + uint64(fRec.length), true } -func (tr tableReader) calcReads(reqs []getRecord, blockSize uint64) (reads int, remaining bool, err error) { +func (tr tableReader) calcReads(reqs []getRecord, blockSize uint64, keeper keeperF) (int, bool, gcBehavior, error) { var offsetRecords offsetRecSlice // Pass #1: Build the set of table locations which must be read in order to find all the elements of |reqs| which are present in this table. - offsetRecords, remaining, err = tr.findOffsets(reqs) + offsetRecords, remaining, gcb, err := tr.findOffsets(reqs, keeper) if err != nil { - return 0, false, err + return 0, false, gcb, err + } + if gcb != gcBehavior_Continue { + return 0, false, gcb, nil } // Now |offsetRecords| contains all locations within the table which must @@ -597,6 +620,7 @@ func (tr tableReader) calcReads(reqs []getRecord, blockSize uint64) (reads int, // location). Scan forward, grouping sequences of reads into large physical // reads. 
+ var reads int var readStart, readEnd uint64 readStarted := false @@ -622,7 +646,7 @@ func (tr tableReader) calcReads(reqs []getRecord, blockSize uint64) (reads int, readStarted = false } - return + return reads, remaining, gcBehavior_Continue, err } func (tr tableReader) extract(ctx context.Context, chunks chan<- extractRecord) error { @@ -681,11 +705,14 @@ func (tr tableReader) reader(ctx context.Context) (io.ReadCloser, uint64, error) return r, sz, nil } -func (tr tableReader) getRecordRanges(ctx context.Context, requests []getRecord) (map[hash.Hash]Range, error) { +func (tr tableReader) getRecordRanges(ctx context.Context, requests []getRecord, keeper keeperF) (map[hash.Hash]Range, gcBehavior, error) { // findOffsets sets getRecord.found - recs, _, err := tr.findOffsets(requests) + recs, _, gcb, err := tr.findOffsets(requests, keeper) if err != nil { - return nil, err + return nil, gcb, err + } + if gcb != gcBehavior_Continue { + return nil, gcb, nil } ranges := make(map[hash.Hash]Range, len(recs)) for _, r := range recs { @@ -694,7 +721,7 @@ func (tr tableReader) getRecordRanges(ctx context.Context, requests []getRecord) Length: r.length, } } - return ranges, nil + return ranges, gcBehavior_Continue, nil } func (tr tableReader) currentSize() uint64 { diff --git a/go/store/nbs/table_set.go b/go/store/nbs/table_set.go index 185743199a0..88fd92de587 100644 --- a/go/store/nbs/table_set.go +++ b/go/store/nbs/table_set.go @@ -58,58 +58,62 @@ type tableSet struct { rl chan struct{} } -func (ts tableSet) has(h hash.Hash) (bool, error) { - f := func(css chunkSourceSet) (bool, error) { +func (ts tableSet) has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) { + f := func(css chunkSourceSet) (bool, gcBehavior, error) { for _, haver := range css { - has, err := haver.has(h) - + has, gcb, err := haver.has(h, keeper) if err != nil { - return false, err + return false, gcb, err + } + if gcb != gcBehavior_Continue { + return false, gcb, nil } - if has { - return true, nil + return true, gcBehavior_Continue, nil } } - return false, nil + return false, gcBehavior_Continue, nil } - novelHas, err := f(ts.novel) - + novelHas, gcb, err := f(ts.novel) if err != nil { - return false, err + return false, gcb, err + } + if gcb != gcBehavior_Continue { + return false, gcb, nil } - if novelHas { - return true, nil + return true, gcBehavior_Continue, nil } return f(ts.upstream) } -func (ts tableSet) hasMany(addrs []hasRecord) (bool, error) { - f := func(css chunkSourceSet) (bool, error) { +func (ts tableSet) hasMany(addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) { + f := func(css chunkSourceSet) (bool, gcBehavior, error) { for _, haver := range css { - has, err := haver.hasMany(addrs) - + has, gcb, err := haver.hasMany(addrs, keeper) if err != nil { - return false, err + return false, gcb, err + } + if gcb != gcBehavior_Continue { + return false, gcb, nil } - if !has { - return false, nil + return false, gcBehavior_Continue, nil } } - return true, nil + return true, gcBehavior_Continue, nil } - remaining, err := f(ts.novel) - + remaining, gcb, err := f(ts.novel) if err != nil { - return false, err + return false, gcb, err + } + if gcb != gcBehavior_Continue { + return remaining, gcb, err } - if !remaining { - return false, nil + return false, gcBehavior_Continue, nil } return f(ts.upstream) @@ -124,7 +128,10 @@ func (ts tableSet) hasMany(addrs []hasRecord) (bool, error) { // consulted. 
Only used for part of the GC workflow where we want to have // access to all chunks in the store but need to check for existing chunk // presence in only a subset of its files. -func (ts tableSet) hasManyInSources(srcs []hash.Hash, addrs []hasRecord) (remaining bool, err error) { +func (ts tableSet) hasManyInSources(srcs []hash.Hash, addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) { + var remaining bool + var err error + var gcb gcBehavior for _, rec := range addrs { if !rec.has { remaining = true @@ -132,7 +139,7 @@ func (ts tableSet) hasManyInSources(srcs []hash.Hash, addrs []hasRecord) (remain } } if !remaining { - return false, nil + return false, gcBehavior_Continue, nil } for _, srcAddr := range srcs { src, ok := ts.novel[srcAddr] @@ -142,83 +149,114 @@ func (ts tableSet) hasManyInSources(srcs []hash.Hash, addrs []hasRecord) (remain continue } } - remaining, err = src.hasMany(addrs) + remaining, gcb, err = src.hasMany(addrs, keeper) if err != nil { - return false, err + return false, gcb, err + } + if gcb != gcBehavior_Continue { + return false, gcb, nil } if !remaining { break } } - return remaining, nil + return remaining, gcBehavior_Continue, nil } -func (ts tableSet) get(ctx context.Context, h hash.Hash, stats *Stats) ([]byte, error) { +func (ts tableSet) get(ctx context.Context, h hash.Hash, keeper keeperF, stats *Stats) ([]byte, gcBehavior, error) { if err := ctx.Err(); err != nil { - return nil, err + return nil, gcBehavior_Continue, err } - f := func(css chunkSourceSet) ([]byte, error) { + f := func(css chunkSourceSet) ([]byte, gcBehavior, error) { for _, haver := range css { - data, err := haver.get(ctx, h, stats) - + data, gcb, err := haver.get(ctx, h, keeper, stats) if err != nil { - return nil, err + return nil, gcb, err + } + if gcb != gcBehavior_Continue { + return nil, gcb, nil } - if data != nil { - return data, nil + return data, gcBehavior_Continue, nil } } - - return nil, nil + return nil, gcBehavior_Continue, nil } - data, err := f(ts.novel) - + data, gcb, err := f(ts.novel) if err != nil { - return nil, err + return nil, gcb, err + } + if gcb != gcBehavior_Continue { + return nil, gcb, nil } - if data != nil { - return data, nil + return data, gcBehavior_Continue, nil } return f(ts.upstream) } -func (ts tableSet) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (remaining bool, err error) { - f := func(css chunkSourceSet) bool { +func (ts tableSet) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { + f := func(css chunkSourceSet) (bool, gcBehavior, error) { for _, haver := range css { - remaining, err = haver.getMany(ctx, eg, reqs, found, stats) + remaining, gcb, err := haver.getMany(ctx, eg, reqs, found, keeper, stats) if err != nil { - return true + return true, gcb, err + } + if gcb != gcBehavior_Continue { + return true, gcb, nil } if !remaining { - return false + return false, gcb, nil } } - return true + return true, gcBehavior_Continue, nil + } + + remaining, gcb, err := f(ts.novel) + if err != nil { + return true, gcb, err + } + if gcb != gcBehavior_Continue { + return true, gcb, nil + } + if !remaining { + return false, gcBehavior_Continue, nil } - return f(ts.novel) && err == nil && f(ts.upstream), err + return f(ts.upstream) } -func (ts tableSet) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, 
CompressedChunk), stats *Stats) (remaining bool, err error) { - f := func(css chunkSourceSet) bool { +func (ts tableSet) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { + f := func(css chunkSourceSet) (bool, gcBehavior, error) { for _, haver := range css { - remaining, err = haver.getManyCompressed(ctx, eg, reqs, found, stats) + remaining, gcb, err := haver.getManyCompressed(ctx, eg, reqs, found, keeper, stats) if err != nil { - return true + return true, gcb, err + } + if gcb != gcBehavior_Continue { + return true, gcb, nil } if !remaining { - return false + return false, gcBehavior_Continue, nil } } + return true, gcBehavior_Continue, nil + } - return true + remaining, gcb, err := f(ts.novel) + if err != nil { + return true, gcb, err + } + if gcb != gcBehavior_Continue { + return remaining, gcb, nil + } + if !remaining { + return false, gcBehavior_Continue, nil } - return f(ts.novel) && err == nil && f(ts.upstream), err + return f(ts.upstream) } func (ts tableSet) count() (uint32, error) { @@ -326,7 +364,7 @@ func (ts tableSet) Size() int { // append adds a memTable to an existing tableSet, compacting |mt| and // returning a new tableSet with newly compacted table added. -func (ts tableSet) append(ctx context.Context, mt *memTable, checker refCheck, hasCache *lru.TwoQueueCache[hash.Hash, struct{}], stats *Stats) (tableSet, error) { +func (ts tableSet) append(ctx context.Context, mt *memTable, checker refCheck, keeper keeperF, hasCache *lru.TwoQueueCache[hash.Hash, struct{}], stats *Stats) (tableSet, gcBehavior, error) { addrs := hash.NewHashSet() for _, getAddrs := range mt.getChildAddrs { getAddrs(ctx, addrs, func(h hash.Hash) bool { return hasCache.Contains(h) }) @@ -342,14 +380,17 @@ func (ts tableSet) append(ctx context.Context, mt *memTable, checker refCheck, h sort.Sort(hasRecordByPrefix(mt.pendingRefs)) absent, err := checker(mt.pendingRefs) if err != nil { - return tableSet{}, err + return tableSet{}, gcBehavior_Continue, err } else if absent.Size() > 0 { - return tableSet{}, fmt.Errorf("%w: found dangling references to %s", ErrDanglingRef, absent.String()) + return tableSet{}, gcBehavior_Continue, fmt.Errorf("%w: found dangling references to %s", ErrDanglingRef, absent.String()) } - cs, err := ts.p.Persist(ctx, mt, ts, stats) + cs, gcb, err := ts.p.Persist(ctx, mt, ts, keeper, stats) if err != nil { - return tableSet{}, err + return tableSet{}, gcBehavior_Continue, err + } + if gcb != gcBehavior_Continue { + return tableSet{}, gcb, nil } newTs := tableSet{ @@ -360,7 +401,7 @@ func (ts tableSet) append(ctx context.Context, mt *memTable, checker refCheck, h rl: ts.rl, } newTs.novel[cs.hash()] = cs - return newTs, nil + return newTs, gcBehavior_Continue, nil } // flatten returns a new tableSet with |upstream| set to the union of ts.novel @@ -500,11 +541,12 @@ func (ts tableSet) toSpecs() ([]tableSpec, error) { return tableSpecs, nil } -func tableSetCalcReads(ts tableSet, reqs []getRecord, blockSize uint64) (reads int, split, remaining bool, err error) { +func tableSetCalcReads(ts tableSet, reqs []getRecord, blockSize uint64, keeper keeperF) (reads int, split, remaining bool, gcb gcBehavior, err error) { all := copyChunkSourceSet(ts.upstream) for a, cs := range ts.novel { all[a] = cs } + gcb = gcBehavior_Continue for _, tbl := range all { rdr, ok := tbl.(*fileTableReader) if !ok { @@ -514,9 +556,12 @@ func tableSetCalcReads(ts tableSet, reqs 
[]getRecord, blockSize uint64) (reads i var n int var more bool - n, more, err = rdr.calcReads(reqs, blockSize) + n, more, gcb, err = rdr.calcReads(reqs, blockSize, keeper) if err != nil { - return 0, false, false, err + return 0, false, false, gcb, err + } + if gcb != gcBehavior_Continue { + return 0, false, false, gcb, nil } reads += n diff --git a/go/store/nbs/table_set_test.go b/go/store/nbs/table_set_test.go index b7d54cbd092..e1cfcef3dad 100644 --- a/go/store/nbs/table_set_test.go +++ b/go/store/nbs/table_set_test.go @@ -41,7 +41,7 @@ var hasManyHasAll = func([]hasRecord) (hash.HashSet, error) { func TestTableSetPrependEmpty(t *testing.T) { hasCache, err := lru.New2Q[hash.Hash, struct{}](1024) require.NoError(t, err) - ts, err := newFakeTableSet(&UnlimitedQuotaProvider{}).append(context.Background(), newMemTable(testMemTableSize), hasManyHasAll, hasCache, &Stats{}) + ts, _, err := newFakeTableSet(&UnlimitedQuotaProvider{}).append(context.Background(), newMemTable(testMemTableSize), hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) specs, err := ts.toSpecs() require.NoError(t, err) @@ -61,7 +61,7 @@ func TestTableSetPrepend(t *testing.T) { mt.addChunk(computeAddr(testChunks[0]), testChunks[0]) hasCache, err := lru.New2Q[hash.Hash, struct{}](1024) require.NoError(t, err) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) firstSpecs, err := ts.toSpecs() @@ -71,7 +71,7 @@ func TestTableSetPrepend(t *testing.T) { mt = newMemTable(testMemTableSize) mt.addChunk(computeAddr(testChunks[1]), testChunks[1]) mt.addChunk(computeAddr(testChunks[2]), testChunks[2]) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) secondSpecs, err := ts.toSpecs() @@ -93,17 +93,17 @@ func TestTableSetToSpecsExcludesEmptyTable(t *testing.T) { mt.addChunk(computeAddr(testChunks[0]), testChunks[0]) hasCache, err := lru.New2Q[hash.Hash, struct{}](1024) require.NoError(t, err) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) mt = newMemTable(testMemTableSize) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) mt = newMemTable(testMemTableSize) mt.addChunk(computeAddr(testChunks[1]), testChunks[1]) mt.addChunk(computeAddr(testChunks[2]), testChunks[2]) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) specs, err = ts.toSpecs() @@ -124,17 +124,17 @@ func TestTableSetFlattenExcludesEmptyTable(t *testing.T) { mt.addChunk(computeAddr(testChunks[0]), testChunks[0]) hasCache, err := lru.New2Q[hash.Hash, struct{}](1024) require.NoError(t, err) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) mt = newMemTable(testMemTableSize) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = 
ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) mt = newMemTable(testMemTableSize) mt.addChunk(computeAddr(testChunks[1]), testChunks[1]) mt.addChunk(computeAddr(testChunks[2]), testChunks[2]) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) ts, err = ts.flatten(context.Background()) @@ -146,7 +146,7 @@ func persist(t *testing.T, p tablePersister, chunks ...[]byte) { for _, c := range chunks { mt := newMemTable(testMemTableSize) mt.addChunk(computeAddr(c), c) - cs, err := p.Persist(context.Background(), mt, nil, &Stats{}) + cs, _, err := p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) require.NoError(t, cs.close()) } @@ -164,7 +164,7 @@ func TestTableSetRebase(t *testing.T) { for _, c := range chunks { mt := newMemTable(testMemTableSize) mt.addChunk(computeAddr(c), c) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) } return ts @@ -213,13 +213,13 @@ func TestTableSetPhysicalLen(t *testing.T) { mt.addChunk(computeAddr(testChunks[0]), testChunks[0]) hasCache, err := lru.New2Q[hash.Hash, struct{}](1024) require.NoError(t, err) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) mt = newMemTable(testMemTableSize) mt.addChunk(computeAddr(testChunks[1]), testChunks[1]) mt.addChunk(computeAddr(testChunks[2]), testChunks[2]) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) assert.True(mustUint64(ts.physicalLen()) > indexSize(mustUint32(ts.count()))) diff --git a/go/store/nbs/table_test.go b/go/store/nbs/table_test.go index 596bebc6890..e62bfc1618e 100644 --- a/go/store/nbs/table_test.go +++ b/go/store/nbs/table_test.go @@ -62,7 +62,7 @@ func buildTable(chunks [][]byte) ([]byte, hash.Hash, error) { } func mustGetString(assert *assert.Assertions, ctx context.Context, tr tableReader, data []byte) string { - bytes, err := tr.get(ctx, computeAddr(data), &Stats{}) + bytes, _, err := tr.get(ctx, computeAddr(data), nil, &Stats{}) assert.NoError(err) return string(bytes) } @@ -106,13 +106,13 @@ func TestSimple(t *testing.T) { func assertChunksInReader(chunks [][]byte, r chunkReader, assert *assert.Assertions) { for _, c := range chunks { - assert.True(r.has(computeAddr(c))) + assert.True(r.has(computeAddr(c), nil)) } } func assertChunksNotInReader(chunks [][]byte, r chunkReader, assert *assert.Assertions) { for _, c := range chunks { - assert.False(r.has(computeAddr(c))) + assert.False(r.has(computeAddr(c), nil)) } } @@ -142,7 +142,7 @@ func TestHasMany(t *testing.T) { } sort.Sort(hasRecordByPrefix(hasAddrs)) - _, err = tr.hasMany(hasAddrs) + _, _, err = tr.hasMany(hasAddrs, nil) require.NoError(t, err) for _, ha := range hasAddrs { assert.True(ha.has, "Nothing for prefix %d", ha.prefix) @@ -192,7 +192,7 @@ func TestHasManySequentialPrefix(t *testing.T) { hasAddrs[0] = hasRecord{&addrs[1], addrs[1].Prefix(), 1, false} hasAddrs[1] = hasRecord{&addrs[2], addrs[2].Prefix(), 2, false} - _, err = tr.hasMany(hasAddrs) + _, _, err = 
tr.hasMany(hasAddrs, nil) require.NoError(t, err) for _, ha := range hasAddrs { @@ -246,7 +246,7 @@ func BenchmarkHasMany(b *testing.B) { b.Run("dense has many", func(b *testing.B) { var ok bool for i := 0; i < b.N; i++ { - ok, err = tr.hasMany(hrecs) + ok, _, err = tr.hasMany(hrecs, nil) } assert.False(b, ok) assert.NoError(b, err) @@ -254,7 +254,7 @@ func BenchmarkHasMany(b *testing.B) { b.Run("sparse has many", func(b *testing.B) { var ok bool for i := 0; i < b.N; i++ { - ok, err = tr.hasMany(sparse) + ok, _, err = tr.hasMany(sparse, nil) } assert.True(b, ok) assert.NoError(b, err) @@ -290,7 +290,7 @@ func TestGetMany(t *testing.T) { eg, ctx := errgroup.WithContext(context.Background()) got := make([]*chunks.Chunk, 0) - _, err = tr.getMany(ctx, eg, getBatch, func(ctx context.Context, c *chunks.Chunk) { got = append(got, c) }, &Stats{}) + _, _, err = tr.getMany(ctx, eg, getBatch, func(ctx context.Context, c *chunks.Chunk) { got = append(got, c) }, nil, &Stats{}) require.NoError(t, err) require.NoError(t, eg.Wait()) @@ -324,13 +324,13 @@ func TestCalcReads(t *testing.T) { gb2 := []getRecord{getBatch[0], getBatch[2]} sort.Sort(getRecordByPrefix(getBatch)) - reads, remaining, err := tr.calcReads(getBatch, 0) + reads, remaining, _, err := tr.calcReads(getBatch, 0, nil) require.NoError(t, err) assert.False(remaining) assert.Equal(1, reads) sort.Sort(getRecordByPrefix(gb2)) - reads, remaining, err = tr.calcReads(gb2, 0) + reads, remaining, _, err = tr.calcReads(gb2, 0, nil) require.NoError(t, err) assert.False(remaining) assert.Equal(2, reads) @@ -398,8 +398,8 @@ func Test65k(t *testing.T) { for i := 0; i < count; i++ { data := dataFn(i) h := computeAddr(data) - assert.True(tr.has(computeAddr(data))) - bytes, err := tr.get(context.Background(), h, &Stats{}) + assert.True(tr.has(computeAddr(data), nil)) + bytes, _, err := tr.get(context.Background(), h, nil, &Stats{}) require.NoError(t, err) assert.Equal(string(data), string(bytes)) } @@ -407,8 +407,8 @@ func Test65k(t *testing.T) { for i := count; i < count*2; i++ { data := dataFn(i) h := computeAddr(data) - assert.False(tr.has(computeAddr(data))) - bytes, err := tr.get(context.Background(), h, &Stats{}) + assert.False(tr.has(computeAddr(data), nil)) + bytes, _, err := tr.get(context.Background(), h, nil, &Stats{}) require.NoError(t, err) assert.NotEqual(string(data), string(bytes)) } @@ -461,7 +461,7 @@ func doTestNGetMany(t *testing.T, count int) { eg, ctx := errgroup.WithContext(context.Background()) got := make([]*chunks.Chunk, 0) - _, err = tr.getMany(ctx, eg, getBatch, func(ctx context.Context, c *chunks.Chunk) { got = append(got, c) }, &Stats{}) + _, _, err = tr.getMany(ctx, eg, getBatch, func(ctx context.Context, c *chunks.Chunk) { got = append(got, c) }, nil, &Stats{}) require.NoError(t, err) require.NoError(t, eg.Wait()) diff --git a/go/store/types/serial_message.go b/go/store/types/serial_message.go index 50ee98c3261..e5c4060f147 100644 --- a/go/store/types/serial_message.go +++ b/go/store/types/serial_message.go @@ -769,10 +769,10 @@ func (sm SerialMessage) WalkAddrs(nbf *NomsBinFormat, cb func(addr hash.Hash) er return err } } - case serial.TableSchemaFileID, serial.ForeignKeyCollectionFileID: + case serial.TableSchemaFileID, serial.ForeignKeyCollectionFileID, serial.TupleFileID: // no further references from these file types return nil - case serial.ProllyTreeNodeFileID, serial.AddressMapFileID, serial.MergeArtifactsFileID, serial.BlobFileID, serial.CommitClosureFileID: + case serial.ProllyTreeNodeFileID, serial.AddressMapFileID, 
serial.MergeArtifactsFileID, serial.BlobFileID, serial.CommitClosureFileID, serial.VectorIndexNodeFileID: return message.WalkAddresses(context.TODO(), serial.Message(sm), func(ctx context.Context, addr hash.Hash) error { return cb(addr) }) diff --git a/go/store/types/value_store.go b/go/store/types/value_store.go index 026a97bcd39..fb348da918f 100644 --- a/go/store/types/value_store.go +++ b/go/store/types/value_store.go @@ -591,7 +591,7 @@ func (lvs *ValueStore) GC(ctx context.Context, mode GCMode, oldGenRefs, newGenRe var oldGenHasMany chunks.HasManyFunc switch mode { case GCModeDefault: - oldGenHasMany = oldGen.HasMany + oldGenHasMany = gcs.OldGenGCFilter() chksMode = chunks.GCMode_Default case GCModeFull: oldGenHasMany = unfilteredHashFunc @@ -601,11 +601,11 @@ func (lvs *ValueStore) GC(ctx context.Context, mode GCMode, oldGenRefs, newGenRe } err := func() error { - err := collector.BeginGC(lvs.gcAddChunk) + err := collector.BeginGC(lvs.gcAddChunk, chksMode) if err != nil { return err } - defer collector.EndGC() + defer collector.EndGC(chksMode) var callCancelSafepoint bool if safepoint != nil { @@ -650,7 +650,7 @@ func (lvs *ValueStore) GC(ctx context.Context, mode GCMode, oldGenRefs, newGenRe } if mode == GCModeDefault { - oldGenHasMany = oldGen.HasMany + oldGenHasMany = gcs.OldGenGCFilter() } else { oldGenHasMany = newFileHasMany } @@ -685,11 +685,11 @@ func (lvs *ValueStore) GC(ctx context.Context, mode GCMode, oldGenRefs, newGenRe newGenRefs.InsertAll(oldGenRefs) err := func() error { - err := collector.BeginGC(lvs.gcAddChunk) + err := collector.BeginGC(lvs.gcAddChunk, chunks.GCMode_Full) if err != nil { return err } - defer collector.EndGC() + defer collector.EndGC(chunks.GCMode_Full) var callCancelSafepoint bool if safepoint != nil { diff --git a/go/store/valuefile/value_file_test.go b/go/store/valuefile/value_file_test.go index 2d410ad0800..318dde9f97d 100644 --- a/go/store/valuefile/value_file_test.go +++ b/go/store/valuefile/value_file_test.go @@ -53,7 +53,7 @@ func TestReadWriteValueFile(t *testing.T) { values = append(values, m) } - path := filepath.Join(os.TempDir(), "file.nvf") + path := filepath.Join(t.TempDir(), "file.nvf") err = WriteValueFile(ctx, path, store, values...) require.NoError(t, err) diff --git a/integration-tests/bats/replication.bats b/integration-tests/bats/replication.bats index 460ebdc3f62..d10420ab689 100644 --- a/integration-tests/bats/replication.bats +++ b/integration-tests/bats/replication.bats @@ -201,6 +201,48 @@ teardown() { [[ "$output" =~ "b1" ]] || false } +# When a replica pulls refs, the remote refs are compared with the local refs to identify which local refs +# need to be deleted. Branches, tags, and remotes all share the ref space, and previous versions of Dolt could +# incorrectly match remote refs to local refs, causing local refs to be removed incorrectly until a later +# run of replica synchronization restored them. +@test "replication: local tag refs are not deleted" { + # Configure repo1 to push changes on commit and create tag a1 + cd repo1 + dolt config --local --add sqlserver.global.dolt_replicate_to_remote remote1 + dolt sql -q "call dolt_tag('a1');" + + # Configure repo2 to pull changes on read + cd ..
+ dolt clone file://./rem1 repo2 + cd repo2 + dolt config --local --add sqlserver.global.dolt_read_replica_remote origin + dolt config --local --add sqlserver.global.dolt_replicate_all_heads 1 + run dolt sql -q "select tag_name from dolt_tags;" + [ "$status" -eq 0 ] + [[ "$output" =~ "| tag_name |" ]] || false + [[ "$output" =~ "| a1 |" ]] || false + + # Create branch new1 in repo1 – in ref path order, "refs/heads/new1" sorts after "refs/heads/main" but before the tag ref "refs/tags/a1", and previous + # versions of Dolt had problems computing which local refs to delete in this case. + cd ../repo1 + dolt sql -q "call dolt_branch('new1');" + + # Confirm that tag a1 has not been deleted. Note that we need to check for this immediately after + # creating branch new1 (i.e. before looking at branches), because the bug in previous versions + # of Dolt would only manifest in the next command and would be fixed by later remote pulls. + cd ../repo2 + run dolt sql -q "select tag_name from dolt_tags;" + [ "$status" -eq 0 ] + [[ "$output" =~ "| tag_name |" ]] || false + [[ "$output" =~ "| a1 |" ]] || false + + # Try again to make sure the results are stable + run dolt sql -q "select tag_name from dolt_tags;" + [ "$status" -eq 0 ] + [[ "$output" =~ "| tag_name |" ]] || false + [[ "$output" =~ "| a1 |" ]] || false +} + @test "replication: pull branch delete current branch" { skip "broken by latest transaction changes" @@ -627,7 +669,6 @@ SQL } @test "replication: pull all heads pulls tags" { - dolt clone file://./rem1 repo2 cd repo2 dolt checkout -b new_feature diff --git a/integration-tests/bats/sql-server.bats b/integration-tests/bats/sql-server.bats index ce6373a6b30..2f06c100ff5 100644 --- a/integration-tests/bats/sql-server.bats +++ b/integration-tests/bats/sql-server.bats @@ -1996,4 +1996,44 @@ EOF run dolt --data-dir datadir1 sql-server --data-dir datadir2 [ $status -eq 1 ] [[ "$output" =~ "cannot specify both global --data-dir argument and --data-dir in sql-server config" ]] || false -} \ No newline at end of file +} + +# This is really a test of the dolt_branches system table, but because it needs a server with multiple dirty branches +# it was easier to test with a sql-server. +@test "sql-server: dirty branches listed properly in dolt_branches table" { + skiponwindows "Missing dependencies" + + cd repo1 + dolt checkout main + dolt branch br1 # will be a clean branch, ahead of main. + dolt branch br2 # will be a dirty branch, on main. + dolt branch br3 # will be a dirty branch, on br1. + start_sql_server repo1 + + dolt --use-db "repo1" --branch br1 sql -q "CREATE TABLE tbl (i int primary key)" + dolt --use-db "repo1" --branch br1 sql -q "CALL DOLT_COMMIT('-Am', 'commit it')" + + dolt --use-db "repo1" --branch br2 sql -q "CREATE TABLE tbl (j int primary key)" + + # Fast-forward br3 to br1, then make it dirty. + dolt --use-db "repo1" --branch br3 sql -q "CALL DOLT_MERGE('br1')" + dolt --use-db "repo1" --branch br3 sql -q "CREATE TABLE othertbl (k int primary key)" + + stop_sql_server 1 && sleep 0.5 + + run dolt sql -q "SELECT name,dirty FROM dolt_branches" + [ "$status" -eq 0 ] + [[ "$output" =~ "br1 | false" ]] || false + [[ "$output" =~ "br2 | true " ]] || false + [[ "$output" =~ "br3 | true" ]] || false + [[ "$output" =~ "main | false" ]] || false + + # Verify that the dolt_branches table shows the same output, regardless of the checked-out branch.
+ dolt checkout br1 + run dolt sql -q "SELECT name,dirty FROM dolt_branches" + [ "$status" -eq 0 ] + [[ "$output" =~ "br1 | false" ]] || false + [[ "$output" =~ "br2 | true " ]] || false + [[ "$output" =~ "br3 | true" ]] || false + [[ "$output" =~ "main | false" ]] || false +} diff --git a/integration-tests/bats/vector-index.bats b/integration-tests/bats/vector-index.bats index 5693b7d8628..8d233f60887 100644 --- a/integration-tests/bats/vector-index.bats +++ b/integration-tests/bats/vector-index.bats @@ -430,3 +430,14 @@ SQL [[ "$output" =~ "pk1" ]] || false [[ "${#lines[@]}" = "1" ]] || false } + +@test "vector-index: can GC" { + dolt sql <