// MapCache is a string-keyed memoization cache guarded by a mutex. For each
// key the producing function is run at most once; later lookups return the
// stored value.
//
// The zero value is ready to use (the backing map is allocated on the first
// store); NewMapCache remains available as an explicit constructor. A MapCache
// holds a sync.Mutex and therefore must not be copied after first use.
type MapCache[T any] struct {
	mu      sync.Mutex
	results map[string]T // guarded by mu; nil until the first value is stored
}

// Get returns the value cached under key. On a miss it calls fn, stores the
// result under key and returns it.
//
// fn is invoked while the cache lock is held, so:
//   - concurrent callers (for any key) wait until fn returns, which is what
//     guarantees fn runs only once per key;
//   - fn must not call Get on the same cache, because sync.Mutex is not
//     reentrant and the nested call would block forever.
//
// Whatever fn returns is memoized, including the zero value of T. A producer
// that reports failure by returning "" or nil will have that result cached.
// If fn panics, the lock is released and nothing is stored.
func (c *MapCache[T]) Get(fn func() T, key string) T {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Reading from a nil map is safe, so the lookup works on a zero-value cache.
	if value, ok := c.results[key]; ok {
		return value
	}

	value := fn()
	// Writing to a nil map panics; allocate lazily so the zero value is usable.
	if c.results == nil {
		c.results = make(map[string]T)
	}
	c.results[key] = value
	return value
}

// NewMapCache returns an empty MapCache with its backing map pre-allocated.
func NewMapCache[T any]() *MapCache[T] {
	return &MapCache[T]{
		results: make(map[string]T),
	}
}
See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package inventory + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type MockFunction struct { + mock.Mock +} + +func (m *MockFunction) GetSomeValue() int { + m.Called() + return 0 +} + +func TestMapCacheGet(t *testing.T) { + cache := NewMapCache[int]() + + // Test getting existing value from cache + cache.results["key1"] = 42 + mockFunction := new(MockFunction) + result := cache.Get(mockFunction.GetSomeValue, "key1") + mockFunction.AssertNotCalled(t, "GetSomeValue") + assert.Equal(t, 42, result) + + // Test getting non-existing value from cache + mockFunction.On("GetSomeValue").Return(mockFunction.GetSomeValue()) + result = cache.Get(mockFunction.GetSomeValue, "key2") + mockFunction.AssertNumberOfCalls(t, "GetSomeValue", 2) // 1 by Return(), 2nd by cache.Get() + assert.Equal(t, 0, result) + + // Test concurrent accesses + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + cache.Get(func() int { return 1 }, "concurrent_key") + }() + } + wg.Wait() +} diff --git a/internal/resources/providers/gcplib/inventory/provider.go b/internal/resources/providers/gcplib/inventory/provider.go index e2eaa5a9c1..ed96288217 100644 --- 
a/internal/resources/providers/gcplib/inventory/provider.go +++ b/internal/resources/providers/gcplib/inventory/provider.go @@ -38,11 +38,11 @@ import ( ) type Provider struct { - log *logp.Logger - config auth.GcpFactoryConfig - inventory *AssetsInventoryWrapper - crm *ResourceManagerWrapper - crmCache map[string]*fetching.CloudAccountMetadata + log *logp.Logger + config auth.GcpFactoryConfig + inventory *AssetsInventoryWrapper + crm *ResourceManagerWrapper + cloudAccountMetadataCache *MapCache[*fetching.CloudAccountMetadata] } type AssetsInventoryWrapper struct { @@ -86,13 +86,6 @@ type ExtendedGcpAsset struct { type ProviderInitializer struct{} -type GcpAssetIDs struct { - orgId string - projectId string - parentProject string - parentOrg string -} - type dnsPolicyFields struct { networks []string enableLogging bool @@ -151,32 +144,38 @@ func (p *ProviderInitializer) Init(ctx context.Context, log *logp.Logger, gcpCon if err != nil { return nil, err } + + displayNamesCache := NewMapCache[string]() // wrap the resource manager client for mocking crmServiceWrapper := &ResourceManagerWrapper{ getProjectDisplayName: func(ctx context.Context, parent string) string { - prj, err := crmService.Projects.Get(parent).Context(ctx).Do() - if err != nil { - log.Errorf("error fetching GCP Project: %s, error: %s", parent, err) - return "" - } - return prj.DisplayName + return displayNamesCache.Get(func() string { + prj, err := crmService.Projects.Get(parent).Context(ctx).Do() + if err != nil { + log.Errorf("error fetching GCP Project: %s, error: %s", parent, err) + return "" + } + return prj.DisplayName + }, parent) }, getOrganizationDisplayName: func(ctx context.Context, parent string) string { - org, err := crmService.Organizations.Get(parent).Context(ctx).Do() - if err != nil { - log.Errorf("error fetching GCP Org: %s, error: %s", parent, err) - return "" - } - return org.DisplayName + return displayNamesCache.Get(func() string { + org, err := 
crmService.Organizations.Get(parent).Context(ctx).Do() + if err != nil { + log.Errorf("error fetching GCP Org: %s, error: %s", parent, err) + return "" + } + return org.DisplayName + }, parent) }, } return &Provider{ - config: gcpConfig, - log: log, - inventory: assetsInventoryWrapper, - crm: crmServiceWrapper, - crmCache: make(map[string]*fetching.CloudAccountMetadata), + config: gcpConfig, + log: log, + inventory: assetsInventoryWrapper, + crm: crmServiceWrapper, + cloudAccountMetadataCache: NewMapCache[*fetching.CloudAccountMetadata](), }, nil } @@ -209,7 +208,7 @@ func (p *Provider) ListAllAssetTypesByName(ctx context.Context, assetTypes []str var assets []*assetpb.Asset assets = append(append(assets, resourceAssets...), policyAssets...) mergedAssets := mergeAssetContentType(assets) - extendedAssets := extendWithECS(ctx, p.crm, p.crmCache, mergedAssets) + extendedAssets := p.extendWithCloudMetadata(ctx, mergedAssets) // Enrich network assets with dns policy p.enrichNetworkAssets(ctx, extendedAssets) @@ -441,22 +440,18 @@ func mergeAssetContentType(assets []*assetpb.Asset) []*assetpb.Asset { } // extends the assets with the project and organization display name -func extendWithECS(ctx context.Context, crm *ResourceManagerWrapper, cache map[string]*fetching.CloudAccountMetadata, assets []*assetpb.Asset) []*ExtendedGcpAsset { +func (p *Provider) extendWithCloudMetadata(ctx context.Context, assets []*assetpb.Asset) []*ExtendedGcpAsset { extendedAssets := make([]*ExtendedGcpAsset, 0, len(assets)) for _, asset := range assets { - keys := getAssetIds(asset) - cacheKey := fmt.Sprintf("%s/%s", keys.parentProject, keys.parentOrg) - if cloudAccount, ok := cache[cacheKey]; ok { - extendedAssets = append(extendedAssets, &ExtendedGcpAsset{ - Asset: asset, - CloudAccount: cloudAccount, - }) - continue - } - cache[cacheKey] = getCloudAccountMetadata(ctx, crm, keys) + orgId := getOrganizationId(asset.Ancestors) + projectId := getProjectId(asset.Ancestors) + cacheKey := 
fmt.Sprintf("%s/%s", projectId, orgId) + cloudAccount := p.cloudAccountMetadataCache.Get(func() *fetching.CloudAccountMetadata { + return p.getCloudAccountMetadata(ctx, projectId, orgId) + }, cacheKey) extendedAssets = append(extendedAssets, &ExtendedGcpAsset{ Asset: asset, - CloudAccount: cache[cacheKey], + CloudAccount: cloudAccount, }) } return extendedAssets @@ -469,71 +464,57 @@ func (p *Provider) ListProjectsAncestorsPolicies(ctx context.Context) ([]*Projec AssetTypes: []string{CrmProjectAssetType}, }) p.log.Infof("Listed %d GCP projects", len(projects)) - ancestorsPolicies := map[string][]*ExtendedGcpAsset{} + ancestorsPoliciesCache := NewMapCache[[]*ExtendedGcpAsset]() return lo.Map(projects, func(project *assetpb.Asset, _ int) *ProjectPoliciesAsset { - projectAsset := extendWithECS(ctx, p.crm, p.crmCache, []*assetpb.Asset{project})[0] + projectAsset := p.extendWithCloudMetadata(ctx, []*assetpb.Asset{project})[0] // Skip first ancestor it as we already got it - policiesAssets := append([]*ExtendedGcpAsset{projectAsset}, getAncestorsAssets(ctx, ancestorsPolicies, p, project.Ancestors[1:])...) + policiesAssets := append([]*ExtendedGcpAsset{projectAsset}, getAncestorsAssets(ctx, ancestorsPoliciesCache, p, project.Ancestors[1:])...) 
return &ProjectPoliciesAsset{CloudAccount: projectAsset.CloudAccount, Policies: policiesAssets} }), nil } -func getAncestorsAssets(ctx context.Context, ancestorsPolicies map[string][]*ExtendedGcpAsset, p *Provider, ancestors []string) []*ExtendedGcpAsset { +func getAncestorsAssets(ctx context.Context, ancestorsPoliciesCache *MapCache[[]*ExtendedGcpAsset], p *Provider, ancestors []string) []*ExtendedGcpAsset { return lo.Flatten(lo.Map(ancestors, func(parent string, _ int) []*ExtendedGcpAsset { - if ancestorsPolicies[parent] != nil { - return ancestorsPolicies[parent] - } - var assetType string - if strings.HasPrefix(parent, "folders") { - assetType = CrmFolderAssetType - } - if strings.HasPrefix(parent, "organizations") { - assetType = CrmOrgAssetType - } - - assets := p.getAllAssets(ctx, &assetpb.ListAssetsRequest{ - ContentType: assetpb.ContentType_IAM_POLICY, - Parent: parent, - AssetTypes: []string{assetType}, - }) - extendedAssets := extendWithECS(ctx, p.crm, p.crmCache, assets) - ancestorsPolicies[parent] = extendedAssets + extendedAssets := ancestorsPoliciesCache.Get(func() []*ExtendedGcpAsset { + var assetType string + if strings.HasPrefix(parent, "folders") { + assetType = CrmFolderAssetType + } + if strings.HasPrefix(parent, "organizations") { + assetType = CrmOrgAssetType + } + return p.extendWithCloudMetadata(ctx, p.getAllAssets(ctx, &assetpb.ListAssetsRequest{ + ContentType: assetpb.ContentType_IAM_POLICY, + Parent: parent, + AssetTypes: []string{assetType}, + })) + }, parent) return extendedAssets })) } -func getAssetIds(asset *assetpb.Asset) GcpAssetIDs { - orgId := getOrganizationId(asset.Ancestors) - projectId := getProjectId(asset.Ancestors) - parentProject := fmt.Sprintf("projects/%s", projectId) - parentOrg := fmt.Sprintf("organizations/%s", orgId) - return GcpAssetIDs{ - orgId: orgId, - projectId: projectId, - parentProject: parentProject, - parentOrg: parentOrg, - } -} - -func getCloudAccountMetadata(ctx context.Context, crm 
*ResourceManagerWrapper, keys GcpAssetIDs) *fetching.CloudAccountMetadata { +func (p *Provider) getCloudAccountMetadata(ctx context.Context, projectId string, orgId string) *fetching.CloudAccountMetadata { var orgName string var projectName string wg := sync.WaitGroup{} wg.Add(1) go func() { - orgName = crm.getOrganizationDisplayName(ctx, keys.parentOrg) + orgName = p.crm.getOrganizationDisplayName(ctx, fmt.Sprintf("organizations/%s", orgId)) wg.Done() }() wg.Add(1) go func() { - projectName = crm.getProjectDisplayName(ctx, keys.parentProject) + // some assets are not associated with a project + if projectId != "" { + projectName = p.crm.getProjectDisplayName(ctx, fmt.Sprintf("projects/%s", projectId)) + } wg.Done() }() wg.Wait() return &fetching.CloudAccountMetadata{ - AccountId: keys.projectId, + AccountId: projectId, AccountName: projectName, - OrganisationId: keys.orgId, + OrganisationId: orgId, OrganizationName: orgName, } } diff --git a/internal/resources/providers/gcplib/inventory/provider_test.go b/internal/resources/providers/gcplib/inventory/provider_test.go index e4b2cb9734..9675acfefe 100644 --- a/internal/resources/providers/gcplib/inventory/provider_test.go +++ b/internal/resources/providers/gcplib/inventory/provider_test.go @@ -89,7 +89,7 @@ func (s *ProviderTestSuite) TestListAllAssetTypesByName() { return "OrganizationName" }, }, - crmCache: make(map[string]*fetching.CloudAccountMetadata), + cloudAccountMetadataCache: NewMapCache[*fetching.CloudAccountMetadata](), } s.mockedIterator.On("Next").Return(&assetpb.Asset{Name: "AssetName1", Resource: &assetpb.Resource{}, Ancestors: []string{"projects/1", "organizations/1"}}, nil).Once() @@ -138,7 +138,7 @@ func (s *ProviderTestSuite) TestListMonitoringAssets() { return "OrganizationName1" }, }, - crmCache: make(map[string]*fetching.CloudAccountMetadata), + cloudAccountMetadataCache: NewMapCache[*fetching.CloudAccountMetadata](), } expected := []*MonitoringAsset{ @@ -218,7 +218,7 @@ func (s 
*ProviderTestSuite) TestEnrichNetworkAssets() { return "OrganizationName" }, }, - crmCache: make(map[string]*fetching.CloudAccountMetadata), + cloudAccountMetadataCache: NewMapCache[*fetching.CloudAccountMetadata](), } assets := []*ExtendedGcpAsset{ @@ -332,7 +332,7 @@ func (s *ProviderTestSuite) TestListServiceUsageAssets() { return "OrganizationName1" }, }, - crmCache: make(map[string]*fetching.CloudAccountMetadata), + cloudAccountMetadataCache: NewMapCache[*fetching.CloudAccountMetadata](), } // asset's resource @@ -421,7 +421,7 @@ func (s *ProviderTestSuite) TestListLoggingAssets() { return "OrganizationName1" }, }, - crmCache: make(map[string]*fetching.CloudAccountMetadata), + cloudAccountMetadataCache: NewMapCache[*fetching.CloudAccountMetadata](), } // asset's resource @@ -460,7 +460,7 @@ func (s *ProviderTestSuite) TestListProjectsAncestorsPolicies() { return "OrganizationName" }, }, - crmCache: make(map[string]*fetching.CloudAccountMetadata), + cloudAccountMetadataCache: NewMapCache[*fetching.CloudAccountMetadata](), } s.mockedIterator.On("Next").Return(&assetpb.Asset{Name: "AssetName1", IamPolicy: &iampb.Policy{}, Ancestors: []string{"projects/1", "organizations/1"}}, nil).Once()