Skip to content

Commit

Permalink
feat: allow specifying providers for cache restore (#126)
Browse files Browse the repository at this point in the history
Allow specifying providers to restore the cache for.  All of the other
`grype-db cache` commands already supported this.

Signed-off-by: Weston Steimel <[email protected]>
  • Loading branch information
westonsteimel authored Jul 11, 2023
1 parent cd5d91c commit d102ad1
Showing 1 changed file with 94 additions and 19 deletions.
113 changes: 94 additions & 19 deletions cmd/grype-db/cli/commands/cache_restore.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"io"
"os"
"path/filepath"
"strings"

"github.com/scylladb/go-set/strset"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/spf13/viper"
Expand All @@ -21,8 +23,11 @@ import (
var _ options.Interface = &cacheRestoreConfig{}

type cacheRestoreConfig struct {
Cache cacheRestoreCache `yaml:"cache" json:"cache" mapstructure:"cache"`
options.Store `yaml:"provider" json:"provider" mapstructure:"provider"`
Cache cacheRestoreCache `yaml:"cache" json:"cache" mapstructure:"cache"`
Provider struct {
options.Store `yaml:",inline" mapstructure:",squash"`
options.Selection `yaml:",inline" mapstructure:",squash"`
} `yaml:"provider" json:"provider" mapstructure:"provider"`
}

type cacheRestoreCache struct {
Expand All @@ -31,14 +36,14 @@ type cacheRestoreCache struct {
}

func (o *cacheRestoreConfig) AddFlags(flags *pflag.FlagSet) {
options.AddAllFlags(flags, &o.Cache.CacheRestore, &o.Cache.CacheArchive, &o.Store)
options.AddAllFlags(flags, &o.Cache.CacheRestore, &o.Cache.CacheArchive, &o.Provider.Store, &o.Provider.Selection)
}

func (o *cacheRestoreConfig) BindFlags(flags *pflag.FlagSet, v *viper.Viper) error {
if err := options.Bind(v, "cache.delete-existing", flags.Lookup("delete-existing")); err != nil {
return err
}
return options.BindAllFlags(flags, v, &o.Cache.CacheRestore, &o.Cache.CacheArchive, &o.Store)
return options.BindAllFlags(flags, v, &o.Cache.CacheRestore, &o.Cache.CacheArchive, &o.Provider.Store, &o.Provider.Selection)
}

func CacheRestore(app *application.Application) *cobra.Command {
Expand All @@ -47,9 +52,11 @@ func CacheRestore(app *application.Application) *cobra.Command {
CacheArchive: options.DefaultCacheArchive(),
CacheRestore: options.DefaultCacheRestore(),
},
Store: options.DefaultStore(),
}

cfg.Provider.Store = options.DefaultStore()
cfg.Provider.Selection = options.DefaultSelection()

cmd := &cobra.Command{
Use: "restore",
Short: "restore provider cache from a backup archive",
Expand All @@ -68,27 +75,41 @@ func CacheRestore(app *application.Application) *cobra.Command {
}

func cacheRestore(cfg cacheRestoreConfig) error {
if err := os.MkdirAll(cfg.Store.Root, 0755); err != nil {
providers := "all"
if len(cfg.Provider.Selection.IncludeFilter) > 0 {
providers = fmt.Sprintf("%s", cfg.Provider.IncludeFilter)
}
log.WithFields("providers", providers).Info("restoring provider state")

if err := os.MkdirAll(cfg.Provider.Root, 0755); err != nil {
return fmt.Errorf("failed to create provider root directory: %w", err)
}

providerNames, err := readProviderNamesFromRoot(cfg.Store.Root)
allowableProviders := strset.New(cfg.Provider.IncludeFilter...)
restorableProviders, err := readProviderNamesFromTarGz(cfg.Cache.CacheArchive.Path)
if err != nil {
return err
}

if cfg.Cache.DeleteExisting {
log.Info("deleting existing provider data")
for _, name := range providerNames {
if err := deleteProviderCache(cfg.Store.Root, name); err != nil {
selectedProviders := strset.New()

for _, name := range restorableProviders {
if allowableProviders.Size() > 0 && !allowableProviders.Has(name) {
log.WithFields("provider", name).Trace("skipping...")
continue
}

selectedProviders.Add(name)

if cfg.Cache.DeleteExisting {
log.WithFields("provider", name).Info("deleting existing provider data")
if err := deleteProviderCache(cfg.Provider.Store.Root, name); err != nil {
return fmt.Errorf("failed to delete provider cache: %w", err)
}
}
} else {
for _, name := range providerNames {
dir := filepath.Join(cfg.Store.Root, name)
} else {
dir := filepath.Join(cfg.Provider.Store.Root, name)
if _, err := os.Stat(dir); !errors.Is(err, os.ErrNotExist) {
log.WithFields("dir", dir).Debug("note: there is pre-existing provider data which could be overwritten by the restore operation")
log.WithFields("provider", name, "dir", dir).Debug("note: there is pre-existing provider data which could be overwritten by the restore operation")
}
}
}
Expand All @@ -104,7 +125,7 @@ func cacheRestore(cfg cacheRestoreConfig) error {
if err != nil {
return err
}
err = os.Chdir(cfg.Store.Root)
err = os.Chdir(cfg.Provider.Store.Root)
if err != nil {
return err
}
Expand All @@ -114,7 +135,7 @@ func cacheRestore(cfg cacheRestoreConfig) error {
}
}(wd)

if err := extractTarGz(f); err != nil {
if err := extractTarGz(f, selectedProviders); err != nil {
return fmt.Errorf("failed to extract cache archive: %w", err)
}

Expand All @@ -123,7 +144,55 @@ func cacheRestore(cfg cacheRestoreConfig) error {
return nil
}

func extractTarGz(reader io.Reader) error {
func getProviderNameFromPath(path string) string {
pathComponents := strings.Split(filepath.Clean(path), string(os.PathSeparator))

if len(pathComponents) > 0 {
return pathComponents[0]
}

return ""
}

func readProviderNamesFromTarGz(tarPath string) ([]string, error) {
f, err := os.Open(tarPath)
if err != nil {
return nil, fmt.Errorf("failed to open cache archive: %w", err)
}

gr, err := gzip.NewReader(f)
if err != nil {
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
}

providers := strset.New()

tr := tar.NewReader(gr)

for {
header, err := tr.Next()

if errors.Is(err, io.EOF) {
break
}

if err != nil {
return nil, fmt.Errorf("failed to read tar header: %w", err)
}

provider := getProviderNameFromPath(header.Name)

if provider != "" {
providers.Add(provider)
}
}

f.Close()

return providers.List(), nil
}

func extractTarGz(reader io.Reader, selectedProviders *strset.Set) error {
gr, err := gzip.NewReader(reader)
if err != nil {
return fmt.Errorf("failed to create gzip reader: %w", err)
Expand All @@ -142,6 +211,12 @@ func extractTarGz(reader io.Reader) error {
return fmt.Errorf("failed to read tar header: %w", err)
}

provider := getProviderNameFromPath(header.Name)
if !selectedProviders.Has(provider) {
log.WithFields("path", header.Name, "provider", provider).Trace("skipping...")
continue
}

log.WithFields("path", header.Name).Trace("extracting file")

switch header.Typeflag {
Expand Down

0 comments on commit d102ad1

Please sign in to comment.