Skip to content

Commit

Permalink
add cache
Browse files Browse the repository at this point in the history
  • Loading branch information
orouz committed Apr 8, 2024
1 parent 89184aa commit f1aa8f7
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 88 deletions.
46 changes: 46 additions & 0 deletions internal/resources/providers/gcplib/inventory/map_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package inventory

import (
"sync"
)

type MapCache[T any] struct {
results map[string]T
mu sync.Mutex
}

func (c *MapCache[T]) Get(fn func() T, key string) T {
c.mu.Lock()
defer c.mu.Unlock()

if value, ok := c.results[key]; ok {
return value
}

value := fn()
c.results[key] = value
return value
}

func NewMapCache[T any]() *MapCache[T] {
return &MapCache[T]{
results: make(map[string]T),
}
}
63 changes: 63 additions & 0 deletions internal/resources/providers/gcplib/inventory/map_cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package inventory

import (
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

type MockFunction struct {
mock.Mock
}

func (m *MockFunction) GetSomeValue() int {
m.Called()
return 0
}

func TestMapCacheGet(t *testing.T) {
cache := NewMapCache[int]()

// Test getting existing value from cache
cache.results["key1"] = 42
mockFunction := new(MockFunction)
result := cache.Get(mockFunction.GetSomeValue, "key1")
mockFunction.AssertNotCalled(t, "GetSomeValue")
assert.Equal(t, 42, result)

// Test getting non-existing value from cache
mockFunction.On("GetSomeValue").Return(mockFunction.GetSomeValue())
result = cache.Get(mockFunction.GetSomeValue, "key2")
mockFunction.AssertNumberOfCalls(t, "GetSomeValue", 2) // 1 by Return(), 2nd by cache.Get()
assert.Equal(t, 0, result)

// Test concurrent accesses
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
cache.Get(func() int { return 1 }, "concurrent_key")
}()
}
wg.Wait()
}
145 changes: 63 additions & 82 deletions internal/resources/providers/gcplib/inventory/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ import (
)

type Provider struct {
log *logp.Logger
config auth.GcpFactoryConfig
inventory *AssetsInventoryWrapper
crm *ResourceManagerWrapper
crmCache map[string]*fetching.CloudAccountMetadata
log *logp.Logger
config auth.GcpFactoryConfig
inventory *AssetsInventoryWrapper
crm *ResourceManagerWrapper
cloudAccountMetadataCache *MapCache[*fetching.CloudAccountMetadata]
}

type AssetsInventoryWrapper struct {
Expand Down Expand Up @@ -86,13 +86,6 @@ type ExtendedGcpAsset struct {

type ProviderInitializer struct{}

type GcpAssetIDs struct {
orgId string
projectId string
parentProject string
parentOrg string
}

type dnsPolicyFields struct {
networks []string
enableLogging bool
Expand Down Expand Up @@ -151,32 +144,38 @@ func (p *ProviderInitializer) Init(ctx context.Context, log *logp.Logger, gcpCon
if err != nil {
return nil, err
}

displayNamesCache := NewMapCache[string]()
// wrap the resource manager client for mocking
crmServiceWrapper := &ResourceManagerWrapper{
getProjectDisplayName: func(ctx context.Context, parent string) string {
prj, err := crmService.Projects.Get(parent).Context(ctx).Do()
if err != nil {
log.Errorf("error fetching GCP Project: %s, error: %s", parent, err)
return ""
}
return prj.DisplayName
return displayNamesCache.Get(func() string {
prj, err := crmService.Projects.Get(parent).Context(ctx).Do()
if err != nil {
log.Errorf("error fetching GCP Project: %s, error: %s", parent, err)
return ""
}
return prj.DisplayName
}, parent)
},
getOrganizationDisplayName: func(ctx context.Context, parent string) string {
org, err := crmService.Organizations.Get(parent).Context(ctx).Do()
if err != nil {
log.Errorf("error fetching GCP Org: %s, error: %s", parent, err)
return ""
}
return org.DisplayName
return displayNamesCache.Get(func() string {
org, err := crmService.Organizations.Get(parent).Context(ctx).Do()
if err != nil {
log.Errorf("error fetching GCP Org: %s, error: %s", parent, err)
return ""
}
return org.DisplayName
}, parent)
},
}

return &Provider{
config: gcpConfig,
log: log,
inventory: assetsInventoryWrapper,
crm: crmServiceWrapper,
crmCache: make(map[string]*fetching.CloudAccountMetadata),
config: gcpConfig,
log: log,
inventory: assetsInventoryWrapper,
crm: crmServiceWrapper,
cloudAccountMetadataCache: NewMapCache[*fetching.CloudAccountMetadata](),
}, nil
}

Expand Down Expand Up @@ -209,7 +208,7 @@ func (p *Provider) ListAllAssetTypesByName(ctx context.Context, assetTypes []str
var assets []*assetpb.Asset
assets = append(append(assets, resourceAssets...), policyAssets...)
mergedAssets := mergeAssetContentType(assets)
extendedAssets := extendWithECS(ctx, p.crm, p.crmCache, mergedAssets)
extendedAssets := p.extendWithCloudMetadata(ctx, mergedAssets)
// Enrich network assets with dns policy
p.enrichNetworkAssets(ctx, extendedAssets)

Expand Down Expand Up @@ -441,22 +440,18 @@ func mergeAssetContentType(assets []*assetpb.Asset) []*assetpb.Asset {
}

// extends the assets with the project and organization display name
func extendWithECS(ctx context.Context, crm *ResourceManagerWrapper, cache map[string]*fetching.CloudAccountMetadata, assets []*assetpb.Asset) []*ExtendedGcpAsset {
func (p *Provider) extendWithCloudMetadata(ctx context.Context, assets []*assetpb.Asset) []*ExtendedGcpAsset {
extendedAssets := make([]*ExtendedGcpAsset, 0, len(assets))
for _, asset := range assets {
keys := getAssetIds(asset)
cacheKey := fmt.Sprintf("%s/%s", keys.parentProject, keys.parentOrg)
if cloudAccount, ok := cache[cacheKey]; ok {
extendedAssets = append(extendedAssets, &ExtendedGcpAsset{
Asset: asset,
CloudAccount: cloudAccount,
})
continue
}
cache[cacheKey] = getCloudAccountMetadata(ctx, crm, keys)
orgId := getOrganizationId(asset.Ancestors)
projectId := getProjectId(asset.Ancestors)
cacheKey := fmt.Sprintf("%s/%s", projectId, orgId)
cloudAccount := p.cloudAccountMetadataCache.Get(func() *fetching.CloudAccountMetadata {
return p.getCloudAccountMetadata(ctx, projectId, orgId)
}, cacheKey)
extendedAssets = append(extendedAssets, &ExtendedGcpAsset{
Asset: asset,
CloudAccount: cache[cacheKey],
CloudAccount: cloudAccount,
})
}
return extendedAssets
Expand All @@ -469,71 +464,57 @@ func (p *Provider) ListProjectsAncestorsPolicies(ctx context.Context) ([]*Projec
AssetTypes: []string{CrmProjectAssetType},
})
p.log.Infof("Listed %d GCP projects", len(projects))
ancestorsPolicies := map[string][]*ExtendedGcpAsset{}
ancestorsPoliciesCache := NewMapCache[[]*ExtendedGcpAsset]()
return lo.Map(projects, func(project *assetpb.Asset, _ int) *ProjectPoliciesAsset {
projectAsset := extendWithECS(ctx, p.crm, p.crmCache, []*assetpb.Asset{project})[0]
projectAsset := p.extendWithCloudMetadata(ctx, []*assetpb.Asset{project})[0]
// Skip first ancestor it as we already got it
policiesAssets := append([]*ExtendedGcpAsset{projectAsset}, getAncestorsAssets(ctx, ancestorsPolicies, p, project.Ancestors[1:])...)
policiesAssets := append([]*ExtendedGcpAsset{projectAsset}, getAncestorsAssets(ctx, ancestorsPoliciesCache, p, project.Ancestors[1:])...)
return &ProjectPoliciesAsset{CloudAccount: projectAsset.CloudAccount, Policies: policiesAssets}
}), nil
}

func getAncestorsAssets(ctx context.Context, ancestorsPolicies map[string][]*ExtendedGcpAsset, p *Provider, ancestors []string) []*ExtendedGcpAsset {
func getAncestorsAssets(ctx context.Context, ancestorsPoliciesCache *MapCache[[]*ExtendedGcpAsset], p *Provider, ancestors []string) []*ExtendedGcpAsset {
return lo.Flatten(lo.Map(ancestors, func(parent string, _ int) []*ExtendedGcpAsset {
if ancestorsPolicies[parent] != nil {
return ancestorsPolicies[parent]
}
var assetType string
if strings.HasPrefix(parent, "folders") {
assetType = CrmFolderAssetType
}
if strings.HasPrefix(parent, "organizations") {
assetType = CrmOrgAssetType
}

assets := p.getAllAssets(ctx, &assetpb.ListAssetsRequest{
ContentType: assetpb.ContentType_IAM_POLICY,
Parent: parent,
AssetTypes: []string{assetType},
})
extendedAssets := extendWithECS(ctx, p.crm, p.crmCache, assets)
ancestorsPolicies[parent] = extendedAssets
extendedAssets := ancestorsPoliciesCache.Get(func() []*ExtendedGcpAsset {
var assetType string
if strings.HasPrefix(parent, "folders") {
assetType = CrmFolderAssetType
}
if strings.HasPrefix(parent, "organizations") {
assetType = CrmOrgAssetType
}
return p.extendWithCloudMetadata(ctx, p.getAllAssets(ctx, &assetpb.ListAssetsRequest{
ContentType: assetpb.ContentType_IAM_POLICY,
Parent: parent,
AssetTypes: []string{assetType},
}))
}, parent)
return extendedAssets
}))
}

func getAssetIds(asset *assetpb.Asset) GcpAssetIDs {
orgId := getOrganizationId(asset.Ancestors)
projectId := getProjectId(asset.Ancestors)
parentProject := fmt.Sprintf("projects/%s", projectId)
parentOrg := fmt.Sprintf("organizations/%s", orgId)
return GcpAssetIDs{
orgId: orgId,
projectId: projectId,
parentProject: parentProject,
parentOrg: parentOrg,
}
}

func getCloudAccountMetadata(ctx context.Context, crm *ResourceManagerWrapper, keys GcpAssetIDs) *fetching.CloudAccountMetadata {
func (p *Provider) getCloudAccountMetadata(ctx context.Context, projectId string, orgId string) *fetching.CloudAccountMetadata {
var orgName string
var projectName string
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
orgName = crm.getOrganizationDisplayName(ctx, keys.parentOrg)
orgName = p.crm.getOrganizationDisplayName(ctx, fmt.Sprintf("organizations/%s", orgId))
wg.Done()
}()
wg.Add(1)
go func() {
projectName = crm.getProjectDisplayName(ctx, keys.parentProject)
// some assets are not associated with a project
if projectId != "" {
projectName = p.crm.getProjectDisplayName(ctx, fmt.Sprintf("projects/%s", projectId))
}
wg.Done()
}()
wg.Wait()
return &fetching.CloudAccountMetadata{
AccountId: keys.projectId,
AccountId: projectId,
AccountName: projectName,
OrganisationId: keys.orgId,
OrganisationId: orgId,
OrganizationName: orgName,
}
}
Expand Down
12 changes: 6 additions & 6 deletions internal/resources/providers/gcplib/inventory/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (s *ProviderTestSuite) TestListAllAssetTypesByName() {
return "OrganizationName"
},
},
crmCache: make(map[string]*fetching.CloudAccountMetadata),
cloudAccountMetadataCache: NewMapCache[*fetching.CloudAccountMetadata](),
}

s.mockedIterator.On("Next").Return(&assetpb.Asset{Name: "AssetName1", Resource: &assetpb.Resource{}, Ancestors: []string{"projects/1", "organizations/1"}}, nil).Once()
Expand Down Expand Up @@ -138,7 +138,7 @@ func (s *ProviderTestSuite) TestListMonitoringAssets() {
return "OrganizationName1"
},
},
crmCache: make(map[string]*fetching.CloudAccountMetadata),
cloudAccountMetadataCache: NewMapCache[*fetching.CloudAccountMetadata](),
}

expected := []*MonitoringAsset{
Expand Down Expand Up @@ -218,7 +218,7 @@ func (s *ProviderTestSuite) TestEnrichNetworkAssets() {
return "OrganizationName"
},
},
crmCache: make(map[string]*fetching.CloudAccountMetadata),
cloudAccountMetadataCache: NewMapCache[*fetching.CloudAccountMetadata](),
}

assets := []*ExtendedGcpAsset{
Expand Down Expand Up @@ -332,7 +332,7 @@ func (s *ProviderTestSuite) TestListServiceUsageAssets() {
return "OrganizationName1"
},
},
crmCache: make(map[string]*fetching.CloudAccountMetadata),
cloudAccountMetadataCache: NewMapCache[*fetching.CloudAccountMetadata](),
}

// asset's resource
Expand Down Expand Up @@ -421,7 +421,7 @@ func (s *ProviderTestSuite) TestListLoggingAssets() {
return "OrganizationName1"
},
},
crmCache: make(map[string]*fetching.CloudAccountMetadata),
cloudAccountMetadataCache: NewMapCache[*fetching.CloudAccountMetadata](),
}

// asset's resource
Expand Down Expand Up @@ -460,7 +460,7 @@ func (s *ProviderTestSuite) TestListProjectsAncestorsPolicies() {
return "OrganizationName"
},
},
crmCache: make(map[string]*fetching.CloudAccountMetadata),
cloudAccountMetadataCache: NewMapCache[*fetching.CloudAccountMetadata](),
}

s.mockedIterator.On("Next").Return(&assetpb.Asset{Name: "AssetName1", IamPolicy: &iampb.Policy{}, Ancestors: []string{"projects/1", "organizations/1"}}, nil).Once()
Expand Down

0 comments on commit f1aa8f7

Please sign in to comment.