diff --git a/cmd/grype-db/cli/commands/cache_restore.go b/cmd/grype-db/cli/commands/cache_restore.go index 00980ea3..bd320109 100644 --- a/cmd/grype-db/cli/commands/cache_restore.go +++ b/cmd/grype-db/cli/commands/cache_restore.go @@ -8,7 +8,9 @@ import ( "io" "os" "path/filepath" + "strings" + "github.com/scylladb/go-set/strset" "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/spf13/viper" @@ -21,8 +23,11 @@ import ( var _ options.Interface = &cacheRestoreConfig{} type cacheRestoreConfig struct { - Cache cacheRestoreCache `yaml:"cache" json:"cache" mapstructure:"cache"` - options.Store `yaml:"provider" json:"provider" mapstructure:"provider"` + Cache cacheRestoreCache `yaml:"cache" json:"cache" mapstructure:"cache"` + Provider struct { + options.Store `yaml:",inline" mapstructure:",squash"` + options.Selection `yaml:",inline" mapstructure:",squash"` + } `yaml:"provider" json:"provider" mapstructure:"provider"` } type cacheRestoreCache struct { @@ -31,14 +36,14 @@ type cacheRestoreCache struct { } func (o *cacheRestoreConfig) AddFlags(flags *pflag.FlagSet) { - options.AddAllFlags(flags, &o.Cache.CacheRestore, &o.Cache.CacheArchive, &o.Store) + options.AddAllFlags(flags, &o.Cache.CacheRestore, &o.Cache.CacheArchive, &o.Provider.Store, &o.Provider.Selection) } func (o *cacheRestoreConfig) BindFlags(flags *pflag.FlagSet, v *viper.Viper) error { if err := options.Bind(v, "cache.delete-existing", flags.Lookup("delete-existing")); err != nil { return err } - return options.BindAllFlags(flags, v, &o.Cache.CacheRestore, &o.Cache.CacheArchive, &o.Store) + return options.BindAllFlags(flags, v, &o.Cache.CacheRestore, &o.Cache.CacheArchive, &o.Provider.Store, &o.Provider.Selection) } func CacheRestore(app *application.Application) *cobra.Command { @@ -47,9 +52,11 @@ func CacheRestore(app *application.Application) *cobra.Command { CacheArchive: options.DefaultCacheArchive(), CacheRestore: options.DefaultCacheRestore(), }, - Store: options.DefaultStore(), } + cfg.Provider.Store = options.DefaultStore() + cfg.Provider.Selection = options.DefaultSelection() + cmd := &cobra.Command{ Use: "restore", Short: "restore provider cache from a backup archive", @@ -68,27 +75,41 @@ func CacheRestore(app *application.Application) *cobra.Command { } func cacheRestore(cfg cacheRestoreConfig) error { - if err := os.MkdirAll(cfg.Store.Root, 0755); err != nil { + providers := "all" + if len(cfg.Provider.Selection.IncludeFilter) > 0 { + providers = fmt.Sprintf("%s", cfg.Provider.IncludeFilter) + } + log.WithFields("providers", providers).Info("restoring provider state") + + if err := os.MkdirAll(cfg.Provider.Root, 0755); err != nil { return fmt.Errorf("failed to create provider root directory: %w", err) } - providerNames, err := readProviderNamesFromRoot(cfg.Store.Root) + allowableProviders := strset.New(cfg.Provider.IncludeFilter...) + restorableProviders, err := readProviderNamesFromTarGz(cfg.Cache.CacheArchive.Path) if err != nil { return err } - if cfg.Cache.DeleteExisting { - log.Info("deleting existing provider data") - for _, name := range providerNames { - if err := deleteProviderCache(cfg.Store.Root, name); err != nil { + selectedProviders := strset.New() + + for _, name := range restorableProviders { + if allowableProviders.Size() > 0 && !allowableProviders.Has(name) { + log.WithFields("provider", name).Trace("skipping...") + continue + } + + selectedProviders.Add(name) + + if cfg.Cache.DeleteExisting { + log.WithFields("provider", name).Info("deleting existing provider data") + if err := deleteProviderCache(cfg.Provider.Store.Root, name); err != nil { return fmt.Errorf("failed to delete provider cache: %w", err) } - } - } else { - for _, name := range providerNames { - dir := filepath.Join(cfg.Store.Root, name) + } else { + dir := filepath.Join(cfg.Provider.Store.Root, name) if _, err := os.Stat(dir); !errors.Is(err, os.ErrNotExist) { - log.WithFields("dir", dir).Debug("note: there is pre-existing provider data which could be overwritten by the restore operation") + log.WithFields("provider", name, "dir", dir).Debug("note: there is pre-existing provider data which could be overwritten by the restore operation") } } } @@ -104,7 +125,7 @@ func cacheRestore(cfg cacheRestoreConfig) error { if err != nil { return err } - err = os.Chdir(cfg.Store.Root) + err = os.Chdir(cfg.Provider.Store.Root) if err != nil { return err } @@ -114,7 +135,7 @@ func cacheRestore(cfg cacheRestoreConfig) error { } }(wd) - if err := extractTarGz(f); err != nil { + if err := extractTarGz(f, selectedProviders); err != nil { return fmt.Errorf("failed to extract cache archive: %w", err) } @@ -123,7 +144,55 @@ func cacheRestore(cfg cacheRestoreConfig) error { return nil } -func extractTarGz(reader io.Reader) error { +func getProviderNameFromPath(path string) string { + pathComponents := strings.Split(filepath.Clean(path), string(os.PathSeparator)) + + if len(pathComponents) > 0 { + return pathComponents[0] + } + + return "" +} + +func readProviderNamesFromTarGz(tarPath string) ([]string, error) { + f, err := os.Open(tarPath) + if err != nil { + return nil, fmt.Errorf("failed to open cache archive: %w", err) + } + + gr, err := gzip.NewReader(f) + if err != nil { + return nil, fmt.Errorf("failed to create gzip reader: %w", err) + } + + providers := strset.New() + + tr := tar.NewReader(gr) + + for { + header, err := tr.Next() + + if errors.Is(err, io.EOF) { + break + } + + if err != nil { + return nil, fmt.Errorf("failed to read tar header: %w", err) + } + + provider := getProviderNameFromPath(header.Name) + + if provider != "" { + providers.Add(provider) + } + } + + f.Close() + + return providers.List(), nil +} + +func extractTarGz(reader io.Reader, selectedProviders *strset.Set) error { gr, err := gzip.NewReader(reader) if err != nil { return fmt.Errorf("failed to create gzip reader: %w", err) @@ -142,6 +211,12 @@ func extractTarGz(reader io.Reader) error { return fmt.Errorf("failed to read tar header: %w", err) } + provider := getProviderNameFromPath(header.Name) + if !selectedProviders.Has(provider) { + log.WithFields("path", header.Name, "provider", provider).Trace("skipping...") + continue + } + log.WithFields("path", header.Name).Trace("extracting file") switch header.Typeflag {