Skip to content

Commit

Permalink
more support for consumer identity provider
Browse files Browse the repository at this point in the history
  • Loading branch information
mandelsoft committed Nov 3, 2023
1 parent 11106c9 commit 7158463
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 28 deletions.
9 changes: 8 additions & 1 deletion pkg/blobaccess/bpi/view.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,19 @@ type blobAccessView struct {
baseblob BlobAccessBase
}

var _ utils.Validatable = (*blobAccessView)(nil)
var (
_ utils.Validatable = (*blobAccessView)(nil)
_ utils.Unwrappable = (*blobAccessView)(nil)
)

func (b *blobAccessView) base() BlobAccessBase {
return b.baseblob
}

func (b *blobAccessView) Unwrap() interface{} {
return b.baseblob
}

func (b *blobAccessView) Close() error {
return b.View.Close()
}
Expand Down
20 changes: 11 additions & 9 deletions pkg/contexts/credentials/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@

package credentials

import (
"github.com/open-component-model/ocm/pkg/utils"
)

func GetProvidedConsumerId(obj interface{}, uctx ...UsageContext) ConsumerIdentity {
if p, ok := obj.(ConsumerIdentityProvider); ok {
return p.GetConsumerId(uctx...)
}
return nil
return utils.UnwrappingCall(obj, func(provider ConsumerIdentityProvider) ConsumerIdentity {
return provider.GetConsumerId()
})
}

func GetProvidedIdentityMatcher(obj interface{}, uctx ...UsageContext) string {
if p, ok := obj.(ConsumerIdentityProvider); ok {
return p.GetIdentityMatcher()
}
return ""
func GetProvidedIdentityMatcher(obj interface{}) string {
return utils.UnwrappingCall(obj, func(provider ConsumerIdentityProvider) string {
return provider.GetIdentityMatcher()
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func (b *artifactHandler) StoreBlob(blob cpi.BlobAccess, artType, hint string, g
log.Debug("oci artifact handler with ocm access source",
generics.AppendedSlice[any](values, "sourcetype", m.Source().AccessSpec().GetType())...,
)
if ocimeth, ok := m.Source().Base().(ociartifact.AccessMethodImpl); !keep && ok {
if ocimeth, ok := m.Source().Unwrap().(ociartifact.AccessMethodImpl); !keep && ok {
art, _, err = ocimeth.GetArtifact()
if err != nil {
return nil, errors.Wrapf(err, "cannot access source artifact")
Expand Down
8 changes: 4 additions & 4 deletions pkg/contexts/ocm/cpi/accspeccpi/methodview.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/open-component-model/ocm/pkg/blobaccess"
"github.com/open-component-model/ocm/pkg/contexts/credentials"
"github.com/open-component-model/ocm/pkg/refmgmt"
"github.com/open-component-model/ocm/pkg/utils"
)

type DigestSource interface {
Expand All @@ -22,9 +23,8 @@ type DigestSource interface {
// into a managed method with multiple views. The original method
// object is closed once the last view is closed.
type AccessMethodView interface {
utils.Unwrappable
AccessMethod

Base() interface{}
}

// AccessMethodForImplementation wrap an access method implementation object
Expand Down Expand Up @@ -61,7 +61,7 @@ var (
_ credentials.ConsumerIdentityProvider = (*accessMethodView)(nil)
)

func (a *accessMethodView) Base() interface{} {
func (a *accessMethodView) Unwrap() interface{} {
return a.methodimpl
}

Expand Down Expand Up @@ -125,7 +125,7 @@ func BlobAccessForAccessMethod(m AccessMethod) (blobaccess.AnnotatedBlobAccess[A

func GetAccessMethodImplementation(m AccessMethod) interface{} {
if v, ok := m.(AccessMethodView); ok {
return v.Base()
return v.Unwrap()
}
return nil
}
26 changes: 16 additions & 10 deletions pkg/contexts/ocm/cpi/dummy.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,43 +112,49 @@ func (d *DummyComponentVersionAccess) GetReferenceByIndex(i int) (ComponentRefer
}

func (d *DummyComponentVersionAccess) AccessMethod(spec AccessSpec) (AccessMethod, error) {
panic("implement me")
if spec.IsLocal(d.Context) {
return nil, errors.ErrNotSupported("local access method")
}
return spec.AccessMethod(d)
}

func (d *DummyComponentVersionAccess) GetInexpensiveContentVersionIdentity(spec AccessSpec) string {
panic("implement me")
if spec.IsLocal(d.Context) {
return ""
}
return spec.GetInexpensiveContentVersionIdentity(d)
}

func (d *DummyComponentVersionAccess) Update() error {
panic("implement me")
return errors.ErrNotSupported("update")
}

func (d *DummyComponentVersionAccess) AddBlob(blob BlobAccess, arttype, refName string, global AccessSpec, opts ...BlobUploadOption) (AccessSpec, error) {
panic("implement me")
return nil, errors.ErrNotSupported("adding blobs")
}

func (d *DummyComponentVersionAccess) SetResourceBlob(meta *ResourceMeta, blob BlobAccess, refname string, global AccessSpec, opts ...BlobModificationOption) error {
panic("implement me")
return errors.ErrNotSupported("adding blobs")
}

func (d *DummyComponentVersionAccess) AdjustResourceAccess(meta *internal.ResourceMeta, acc compdesc.AccessSpec, opts ...ModificationOption) error {
panic("implement me")
return errors.ErrNotSupported("resource modification")
}

func (d *DummyComponentVersionAccess) SetResource(meta *ResourceMeta, spec compdesc.AccessSpec, opts ...ModificationOption) error {
panic("implement me")
return errors.ErrNotSupported("resource modification")
}

func (d *DummyComponentVersionAccess) SetResourceAccess(art ResourceAccess, modopts ...BlobModificationOption) error {
panic("implement me")
return errors.ErrNotSupported("resource modification")
}

func (d *DummyComponentVersionAccess) SetSourceBlob(meta *SourceMeta, blob BlobAccess, refname string, global AccessSpec) error {
panic("implement me")
return errors.ErrNotSupported("source modification")
}

func (d *DummyComponentVersionAccess) SetSource(meta *SourceMeta, spec compdesc.AccessSpec) error {
panic("implement me")
return errors.ErrNotSupported("source modification")
}

func (d *DummyComponentVersionAccess) SetSourceByAccess(art SourceAccess) error {
Expand Down
23 changes: 21 additions & 2 deletions pkg/contexts/ocm/cpi/view.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ type repositoryView struct {
var (
_ Repository = (*repositoryView)(nil)
_ credentials.ConsumerIdentityProvider = (*repositoryView)(nil)
_ utils.Unwrappable = (*repositoryView)(nil)
)

func GetRepositoryImplementation(n Repository) (RepositoryImpl, error) {
Expand Down Expand Up @@ -114,6 +115,10 @@ func NewRepository(impl RepositoryImpl, name ...string) Repository {
return resource.NewResource[Repository](impl, repositoryViewCreator, utils.OptionalDefaulted("OCM repo", name...), true)
}

func (r *repositoryView) Unwrap() interface{} {
return r.impl
}

func (r *repositoryView) GetConsumerId(uctx ...credentials.UsageContext) credentials.ConsumerIdentity {
return credentials.GetProvidedConsumerId(r.impl, uctx...)
}
Expand Down Expand Up @@ -231,7 +236,10 @@ type componentAccessView struct {
impl ComponentAccessImpl
}

var _ ComponentAccess = (*componentAccessView)(nil)
var (
_ ComponentAccess = (*componentAccessView)(nil)
_ utils.Unwrappable = (*componentAccessView)(nil)
)

func GetComponentAccessImplementation(n ComponentAccess) (ComponentAccessImpl, error) {
if v, ok := n.(*componentAccessView); ok {
Expand All @@ -251,6 +259,10 @@ func NewComponentAccess(impl ComponentAccessImpl, kind ...string) ComponentAcces
return resource.NewResource[ComponentAccess](impl, componentAccessViewCreator, fmt.Sprintf("%s %s", utils.OptionalDefaulted("component", kind...), impl.GetName()), true)
}

func (c *componentAccessView) Unwrap() interface{} {
return c.impl
}

func (c *componentAccessView) GetContext() Context {
return c.impl.GetContext()
}
Expand Down Expand Up @@ -590,7 +602,10 @@ type componentVersionAccessView struct {
impl ComponentVersionAccessImpl
}

var _ ComponentVersionAccess = (*componentVersionAccessView)(nil)
var (
_ ComponentVersionAccess = (*componentVersionAccessView)(nil)
_ utils.Unwrappable = (*componentVersionAccessView)(nil)
)

func GetComponentVersionAccessImplementation(n ComponentVersionAccess) (ComponentVersionAccessImpl, error) {
if v, ok := n.(*componentVersionAccessView); ok {
Expand All @@ -610,6 +625,10 @@ func NewComponentVersionAccess(impl ComponentVersionAccessImpl) ComponentVersion
return resource.NewResource[ComponentVersionAccess](impl, artifactAccessViewCreator, fmt.Sprintf("component version %s/%s", impl.GetName(), impl.GetVersion()), true)
}

func (c *componentVersionAccessView) Unwrap() interface{} {
return c.impl
}

func (c *componentVersionAccessView) Close() error {
err := c.Execute(func() error {
// executed under local lock, if refcount is one, I'm the last user.
Expand Down
21 changes: 21 additions & 0 deletions pkg/contexts/ocm/cpi/view_rsc.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"

"github.com/open-component-model/ocm/pkg/blobaccess"
"github.com/open-component-model/ocm/pkg/contexts/credentials"
"github.com/open-component-model/ocm/pkg/contexts/ocm/compdesc"
ocm "github.com/open-component-model/ocm/pkg/contexts/ocm/context"
"github.com/open-component-model/ocm/pkg/contexts/ocm/cpi/accspeccpi"
Expand Down Expand Up @@ -189,6 +190,8 @@ type artifactAccessProvider[M any] struct {
meta *M
}

var _ credentials.ConsumerIdentityProvider = (*artifactAccessProvider[any])(nil)

type artifactCVAccessProvider[M any] struct {
artifactAccessProvider[M]
componentVersionProvider
Expand All @@ -212,6 +215,24 @@ func (r *artifactAccessProvider[M]) Meta() *M {
return r.meta
}

func (b *artifactAccessProvider[M]) GetConsumerId(uctx ...credentials.UsageContext) credentials.ConsumerIdentity {
m, err := b.AccessMethod()
if err != nil {
return nil
}
defer m.Close()
return credentials.GetProvidedConsumerId(m, uctx...)
}

func (b *artifactAccessProvider[M]) GetIdentityMatcher() string {
m, err := b.AccessMethod()
if err != nil {
return ""
}
defer m.Close()
return credentials.GetProvidedIdentityMatcher(m)
}

////////////////////////////////////////////////////////////////////////////////

var _ ResourceAccess = (*artifactAccessProvider[ResourceMeta])(nil)
Expand Down
2 changes: 1 addition & 1 deletion pkg/contexts/ocm/internal/accesstypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ type AccessMethodImpl interface {

// AccessMethod is used to support independently closable
// views on an access method implementation, which can
// be passed arround and stored. The original method implementation
// be passed around and stored. The original method implementation
// object is closed once the last view is closed.
type AccessMethod interface {
refmgmt.Dup[AccessMethod]
Expand Down
50 changes: 50 additions & 0 deletions pkg/utils/unwrap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// SPDX-FileCopyrightText: 2023 SAP SE or an SAP affiliate company and Open Component Model contributors.
//
// SPDX-License-Identifier: Apache-2.0

package utils

type Unwrappable interface {
Unwrap() interface{}
}

func Unwrap(o interface{}) interface{} {
if o != nil {
if u, ok := o.(Unwrappable); ok {
return u.Unwrap()
}
}
return nil
}

func UnwrappingCast[I interface{}](o interface{}) I {
var _nil I

for o != nil {
if i, ok := o.(I); ok {
return i
}
if i := Unwrap(o); i != o {
o = i
} else {
o = nil
}
}
return _nil
}

func UnwrappingCall[R any, I any](o interface{}, f func(I) R) R {
var _nil R

for o != nil {
if i, ok := o.(I); ok {
return f(i)
}
if i := Unwrap(o); i != o {
o = i
} else {
o = nil
}
}
return _nil
}

0 comments on commit 7158463

Please sign in to comment.