diff --git a/internal/jimm/jimm.go b/internal/jimm/jimm.go index af8edac89..b4eaa1bd0 100644 --- a/internal/jimm/jimm.go +++ b/internal/jimm/jimm.go @@ -195,8 +195,9 @@ type SSHManager interface { // PublicKeyHandler is the method to verify the public key of the user. It returns a user if successful. PublicKeyHandler(ctx context.Context, claimUser string, key []byte) (*openfga.User, error) - // ResolveAddressesFromModelUUID is the method to resolve the address of the controller to contact given the model UUID. - ResolveAddressesFromModelUUID(ctx context.Context, modelUUID string) ([]string, error) + // ControllerInfoFromModelUUID is the method to resolve the address of the controller to contact given the model UUID and + // a valid JWT To connect to the controller. + ControllerInfoFromModelUUID(ctx context.Context, modelUUID string, user *openfga.User) (ssh.ControllerInfo, error) } // JujuManager is the interface to manage all Juju related operations. @@ -420,7 +421,7 @@ func New(p Parameters) (*JIMM, error) { } j.sshKeyManager = sshKeyManager - sshManager, err := ssh.NewSSHManager(j.identityManager, j.jujuManager, j.sshKeyManager) + sshManager, err := ssh.NewSSHManager(j.identityManager, j.jujuManager, j.sshKeyManager, j.jujuAuthFactory) if err != nil { return nil, err } diff --git a/internal/jimm/ssh/ssh.go b/internal/jimm/ssh/ssh.go index 87060d823..be4db0acf 100644 --- a/internal/jimm/ssh/ssh.go +++ b/internal/jimm/ssh/ssh.go @@ -10,10 +10,19 @@ import ( "github.com/canonical/jimm/v3/internal/dbmodel" "github.com/canonical/jimm/v3/internal/errors" + "github.com/canonical/jimm/v3/internal/jimm/jujuauth" "github.com/canonical/jimm/v3/internal/openfga" "github.com/canonical/jimm/v3/internal/rpc" ) +// ControllerInfo is the struct holding the infomation to contact a controller +type ControllerInfo struct { + // addresses to dial the controller + Addresses []string + // JWT to authenticate to the controller + JWT string +} + // IdentityManager provides a means to fetch an identity from the identity service. type IdentityManager interface { FetchIdentity(ctx context.Context, id string) (*openfga.User, error) @@ -29,8 +38,13 @@ type SSHKeyManager interface { VerifyPublicKey(ctx context.Context, claimUser string, publicKey []byte) (bool, error) } +// JWTGeneratorFactory provides a means to create a token generator. +type JWTGeneratorFactory interface { + New() jujuauth.TokenGenerator +} + // NewSSHManager returns a new SSHManager that offers jimm functionality to the SSHJumpServer. -func NewSSHManager(identityManager IdentityManager, modelManager ModelManager, sshKeyManager SSHKeyManager) (*sshManager, error) { +func NewSSHManager(identityManager IdentityManager, modelManager ModelManager, sshKeyManager SSHKeyManager, jwtFactory JWTGeneratorFactory) (*sshManager, error) { if identityManager == nil { return nil, errors.E("identityManager cannot be nil") } @@ -40,10 +54,14 @@ func NewSSHManager(identityManager IdentityManager, modelManager ModelManager, s if sshKeyManager == nil { return nil, errors.E("sshManager cannot be nil") } + if jwtFactory == nil { + return nil, errors.E("jwtFactory cannot be nil") + } return &sshManager{ modelManager: modelManager, identityManager: identityManager, sshKeyManager: sshKeyManager, + jwtFactory: jwtFactory, }, nil } @@ -52,6 +70,7 @@ type sshManager struct { modelManager ModelManager identityManager IdentityManager sshKeyManager SSHKeyManager + jwtFactory JWTGeneratorFactory } // PublicKeyHandler is the method to verify the public key of the user. It first checks for the public key and then fetches the user. @@ -69,19 +88,27 @@ func (s *sshManager) PublicKeyHandler(ctx context.Context, claimUser string, key return user, nil } -// ResolveAddressesFromModelUUID is the method to resolve the address of the controller to contact given the model UUID. -func (s *sshManager) ResolveAddressesFromModelUUID(ctx context.Context, modelUUID string) ([]string, error) { - zapctx.Info(ctx, "ResolveAddressesFromModelUUID") - +// ControllerInfoFromModelUUID is the method to resolve the address of the controller to contact given the model UUID and +// a valid JWT To connect to the controller. +func (s *sshManager) ControllerInfoFromModelUUID(ctx context.Context, modelUUID string, user *openfga.User) (ControllerInfo, error) { + zapctx.Info(ctx, "ControllerInfoFromModelUUID") model, err := s.modelManager.GetModel(ctx, modelUUID) if err != nil { - return nil, errors.E(err, "cannot find model") + return ControllerInfo{}, errors.E(err, "cannot find model") } - addrs, _ := rpc.GetAddressesAndTLSConfig(ctx, &model.Controller) if len(addrs) == 0 { - return nil, errors.E(err, "cannot find addresses for model's controller") + return ControllerInfo{}, errors.E(err, "cannot find addresses for model's controller") + } + jwtGenerator := s.jwtFactory.New() + jwtGenerator.SetTags(model.ResourceTag(), model.Controller.ResourceTag()) + jwt, err := jwtGenerator.MakeLoginToken(ctx, user) + if err != nil { + return ControllerInfo{}, errors.E(err, "cannot generate jwt") } - return addrs, nil + return ControllerInfo{ + Addresses: addrs, + JWT: string(jwt), + }, nil } diff --git a/internal/jimm/ssh/ssh_test.go b/internal/jimm/ssh/ssh_test.go index fdfc52444..39c7215ce 100644 --- a/internal/jimm/ssh/ssh_test.go +++ b/internal/jimm/ssh/ssh_test.go @@ -3,17 +3,174 @@ package ssh_test import ( + "context" + "crypto/rand" + "crypto/rsa" + "database/sql" "testing" + "time" qt "github.com/frankban/quicktest" + "github.com/frankban/quicktest/qtsuite" + "github.com/juju/names/v5" + gossh "golang.org/x/crypto/ssh" + "github.com/canonical/jimm/v3/internal/db" + "github.com/canonical/jimm/v3/internal/dbmodel" + "github.com/canonical/jimm/v3/internal/jimm" + "github.com/canonical/jimm/v3/internal/jimm/identity" + "github.com/canonical/jimm/v3/internal/jimm/jujuauth" + "github.com/canonical/jimm/v3/internal/jimm/permissions" "github.com/canonical/jimm/v3/internal/jimm/ssh" + "github.com/canonical/jimm/v3/internal/jimm/sshkeys" + "github.com/canonical/jimm/v3/internal/jimmjwx" + "github.com/canonical/jimm/v3/internal/openfga" + "github.com/canonical/jimm/v3/internal/testutils/jimmtest" "github.com/canonical/jimm/v3/internal/testutils/jimmtest/mocks" ) -func TestSSHManagerCreation(t *testing.T) { - c := qt.New(t) - // TODO(simonedutto): add proper testing when implementing the sshkeymanager VerifyPublicKey method. - _, err := ssh.NewSSHManager(&mocks.IdentityManager{}, &mocks.ModelManager{}, &mocks.SSHKeyManager{}) +type sshManagerSuite struct { + publicKey sshkeys.PublicKey + allowedModelUUID string + allowedControllerUUID string + + sshManager jimm.SSHManager + + userWithAccess *openfga.User + userWithoutAccess *openfga.User +} + +const testSSHManagerEnv = ` +cloud-credentials: +- name: test-cred + cloud: test + owner: alice@canonical.com + type: empty +clouds: +- name: test + type: test + regions: + - name: test-region +controllers: +- name: test + uuid: 00000001-0000-0000-0000-000000000001 + cloud: test + region: test-region + public-address: localhost + +models: +- name: test-1 + uuid: 00000002-0000-0000-0000-000000000001 + owner: alice@canonical.com + cloud: test + region: test-region + cloud-credential: test-cred + controller: test + users: + - user: alice@canonical.com + access: admin +users: +- username: alice@canonical.com + controller-access: superuser +` + +func (s *sshManagerSuite) Init(c *qt.C) { + ctx := context.Background() + uuid := "00000002-0000-0000-0000-000000000001" + jimmTag := names.NewControllerTag(uuid) + // Setup DB + db := &db.Database{ + DB: jimmtest.PostgresDB(c, time.Now), + } + err := db.Migrate(context.Background()) + c.Assert(err, qt.IsNil) + // Setup OFGA + ofgaClient, _, _, err := jimmtest.SetupTestOFGAClient(c.Name()) + c.Assert(err, qt.IsNil) + + identityManager, err := identity.NewIdentityManager(db, ofgaClient) + c.Assert(err, qt.IsNil) + + // this is a mock non-mock model manager, bandaid until we have a real model manager to avoid creating a whole jimm. + modelManager := mocks.ModelManager{ + GetModel_: func(ctx context.Context, uuid string) (dbmodel.Model, error) { + m := dbmodel.Model{ + UUID: sql.NullString{ + String: uuid, + Valid: true, + }, + } + err := db.GetModel(ctx, &m) + return m, err + }, + } + permissionManager, err := permissions.NewManager(db, ofgaClient, uuid, jimmTag) + c.Assert(err, qt.IsNil) + jwtFactory := jujuauth.NewFactory(db, mocks.JWTService{ + NewJWT_: func(ctx context.Context, j jimmjwx.JWTParams) ([]byte, error) { + return []byte("jwt"), nil + }, + }, permissionManager) + + sshKeyManager, err := sshkeys.NewSSHKeyManager(db) + c.Assert(err, qt.IsNil) + + s.sshManager, err = ssh.NewSSHManager(identityManager, &modelManager, sshKeyManager, jwtFactory) + c.Assert(err, qt.IsNil) + env := jimmtest.ParseEnvironment(c, testSSHManagerEnv) + env.PopulateDB(c, db) + env.PopulateDBAndPermissions(c, jimmTag, db, ofgaClient) + // create a user and set permission for one model + s.userWithAccess, err = identityManager.FetchIdentity(ctx, env.Users[0].Username) + c.Assert(err, qt.IsNil) + s.allowedModelUUID = env.Models[0].UUID + s.allowedControllerUUID = env.Controllers[0].UUID + + // create a user without access + i2, err := dbmodel.NewIdentity("bob") c.Assert(err, qt.IsNil) + c.Assert(db.DB.Create(i2).Error, qt.IsNil) + s.userWithoutAccess = openfga.NewUser(i2, ofgaClient) + // setup public key + key, err := rsa.GenerateKey(rand.Reader, 2048) + c.Assert(err, qt.IsNil) + + pubKey, err := gossh.NewPublicKey(&key.PublicKey) + c.Assert(err, qt.IsNil) + s.publicKey = sshkeys.PublicKey{PublicKey: pubKey, Comment: "myComment"} + + c.Assert(err, qt.IsNil) + err = sshKeyManager.AddUserPublicKey(ctx, s.userWithAccess, s.publicKey) + c.Assert(err, qt.IsNil) +} + +func (s *sshManagerSuite) TestPublicKeyHandler(c *qt.C) { + ctx := context.Background() + + // Test that the PublicKeyHandler returns the correct user when the public key is valid. + user, err := s.sshManager.PublicKeyHandler(ctx, s.userWithAccess.Name, s.publicKey.Marshal()) + c.Assert(err, qt.IsNil) + c.Assert(user.Identity.Name, qt.Equals, "alice@canonical.com") + + // Test that the PublicKeyHandler returns an error when the public key is invalid. + _, err = s.sshManager.PublicKeyHandler(ctx, s.userWithoutAccess.Name, s.publicKey.Marshal()) + c.Assert(err, qt.ErrorMatches, `cannot verify key for user`) +} + +func (s *sshManagerSuite) TestControllerInfoFromModelUUID(c *qt.C) { + ctx := context.Background() + + // Test that the ControllerInfoFromModelUUID returns the correct controller address and user when the model UUID is valid. + connInfo, err := s.sshManager.ControllerInfoFromModelUUID(ctx, s.allowedModelUUID, s.userWithAccess) + c.Assert(err, qt.IsNil) + c.Assert(connInfo.Addresses, qt.HasLen, 1) + c.Assert(connInfo.JWT, qt.Not(qt.HasLen), 0) + + // Test that the ControllerInfoFromModelUUID returns an error when the model UUID is invalid. + _, err = s.sshManager.ControllerInfoFromModelUUID(ctx, "not-valid", s.userWithAccess) + c.Assert(err, qt.ErrorMatches, ".*cannot find model.*") +} + +func TestSSHManager(t *testing.T) { + qtsuite.Run(qt.New(t), &sshManagerSuite{}) } diff --git a/internal/ssh/dial.go b/internal/ssh/dial.go index b4ed01739..eb5da59df 100644 --- a/internal/ssh/dial.go +++ b/internal/ssh/dial.go @@ -11,21 +11,22 @@ import ( gossh "golang.org/x/crypto/ssh" "github.com/canonical/jimm/v3/internal/errors" + jimmssh "github.com/canonical/jimm/v3/internal/jimm/ssh" ) // dialControllerSSHServer dials the controller ssh server, trying the addresses sequentially and returning a go ssh client. -func dialControllerSSHServer(addrs []string, destPort uint32) (*gossh.Client, error) { +func dialControllerSSHServer(connInfo jimmssh.ControllerInfo, destPort uint32) (*gossh.Client, error) { var client *gossh.Client var err error var errs []error - for _, addr := range addrs { + for _, addr := range connInfo.Addresses { dest := net.JoinHostPort(addr, fmt.Sprint(destPort)) client, err = gossh.Dial("tcp", dest, &gossh.ClientConfig{ //nolint:gosec // this will be removed once we handle hostkeys HostKeyCallback: gossh.InsecureIgnoreHostKey(), Auth: []gossh.AuthMethod{ gossh.PasswordCallback(func() (secret string, err error) { - return "jwt", nil + return connInfo.JWT, nil }), }, Timeout: 5 * time.Second, diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index c861c3d2f..47226e9f1 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -15,6 +15,7 @@ import ( "go.uber.org/zap" gossh "golang.org/x/crypto/ssh" + jimmssh "github.com/canonical/jimm/v3/internal/jimm/ssh" "github.com/canonical/jimm/v3/internal/openfga" ) @@ -31,8 +32,9 @@ type SSHManager interface { // PublicKeyHandler is the method to verify the public key of the user. It returns a user if successful. PublicKeyHandler(ctx context.Context, claimUser string, key []byte) (*openfga.User, error) - // ResolveAddressesFromModelUUID is the method to resolve the address of the controller to contact given the model UUID. - ResolveAddressesFromModelUUID(ctx context.Context, modelUUID string) ([]string, error) + // ControllerInfoFromModelUUID is the method to resolve the address of the controller to contact given the model UUID and + // a valid JWT To connect to the controller. + ControllerInfoFromModelUUID(ctx context.Context, modelUUID string, user *openfga.User) (jimmssh.ControllerInfo, error) } // forwardMessage is the struct holding the information about the jump message received by the ssh client. @@ -137,18 +139,17 @@ func directTCPIPHandler(sshManager SSHManager) func(srv *ssh.Server, conn *gossh return } modelTag := names.NewModelTag(d.DestAddr) - // user is now ignored, but it will be needed for the jwt auth next-up. - _, err := fetchAndAuthorizeUser(ctx, modelTag) + user, err := fetchAndAuthorizeUser(ctx, modelTag) if err != nil { rejectConnectionAndLogError(ctx, newChan, err.Error(), err) return } - addrs, err := sshManager.ResolveAddressesFromModelUUID(ctx, modelTag.Id()) + connInfo, err := sshManager.ControllerInfoFromModelUUID(ctx, modelTag.Id(), user) if err != nil { - rejectConnectionAndLogError(ctx, newChan, "failed to resolve address from model uuid", err) + rejectConnectionAndLogError(ctx, newChan, "failed to get connection info", err) return } - client, err := dialControllerSSHServer(addrs, d.DestPort) + client, err := dialControllerSSHServer(connInfo, d.DestPort) if err != nil { rejectConnectionAndLogError(ctx, newChan, fmt.Sprintf("failed to dial controller ssh: %v", err), err) return diff --git a/internal/ssh/ssh_test.go b/internal/ssh/ssh_test.go index 8a09054a8..237772949 100644 --- a/internal/ssh/ssh_test.go +++ b/internal/ssh/ssh_test.go @@ -22,6 +22,7 @@ import ( "github.com/canonical/jimm/v3/internal/db" "github.com/canonical/jimm/v3/internal/dbmodel" + jimmssh "github.com/canonical/jimm/v3/internal/jimm/ssh" "github.com/canonical/jimm/v3/internal/openfga" ofganames "github.com/canonical/jimm/v3/internal/openfga/names" "github.com/canonical/jimm/v3/internal/ssh" @@ -88,6 +89,9 @@ func (s *sshSuite) Init(c *qt.C) { s.received <- true }, }, + PasswordHandler: func(ctx gliderssh.Context, password string) bool { + return "valid-jwt" == password + }, } go func() { _ = s.destinationJujuSSHServer.ListenAndServe() @@ -123,8 +127,11 @@ func (s *sshSuite) Init(c *qt.C) { } return userWithoutAccess, nil }, - ResolveAddressesFromModelUUID_: func(ctx context.Context, modelUUID string) ([]string, error) { - return []string{""}, nil + ControllerInfoFromModelUUID_: func(ctx context.Context, modelUUID string, user *openfga.User) (jimmssh.ControllerInfo, error) { + if user == userWithAccess { + return jimmssh.ControllerInfo{Addresses: []string{""}, JWT: "valid-jwt"}, nil + } + return jimmssh.ControllerInfo{Addresses: []string{""}, JWT: ""}, nil }, }) c.Assert(err, qt.IsNil) @@ -168,19 +175,12 @@ func (s *sshSuite) TestSSHJump(c *qt.C) { defer client.Close() // send forward message - msg := ssh.ForwardMessage{ - DestAddr: s.allowedModelUUID, - //nolint:gosec - DestPort: uint32(s.destinationServerPort), - SrcAddr: "localhost", - SrcPort: 0, - } s.testInDestinationServerF = func(fm ssh.ForwardMessage) { c.Check(fm.DestAddr, qt.Equals, s.allowedModelUUID) } - ch, _, err := client.OpenChannel("direct-tcpip", gossh.Marshal(&msg)) + conn, err := client.Dial("tcp", fmt.Sprintf("%s:%d", s.allowedModelUUID, s.destinationServerPort)) c.Check(err, qt.IsNil) - defer ch.Close() + defer conn.Close() select { case <-s.received: case <-time.After(100 * time.Millisecond): @@ -189,47 +189,48 @@ func (s *sshSuite) TestSSHJump(c *qt.C) { } func (s *sshSuite) TestSSHJumpPermissionFail(c *qt.C) { - client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{ - HostKeyCallback: gossh.FixedHostKey(s.hostKey.PublicKey()), - Auth: []gossh.AuthMethod{ - gossh.PublicKeys(s.privateKey), + tests := []struct { + name string + user string + destAddr string + errMsg string + }{ + { + name: "alice not allowed on this model", + user: "alice", + destAddr: "982b16d9-a945-4762-b684-fd4fd885aa11", + errMsg: "ssh: rejected: connect failed (user doesn't have permission)", + }, + { + name: "bob not allowed on this model", + user: "bob", + destAddr: s.allowedModelUUID, + errMsg: "ssh: rejected: connect failed (user doesn't have permission)", + }, + { + name: "not existing user", + user: "mark", + destAddr: s.allowedModelUUID, + errMsg: "ssh: rejected: connect failed (user doesn't have permission)", }, - User: "alice", - }) - c.Assert(err, qt.IsNil) - defer client.Close() - - // send forward message - msg := ssh.ForwardMessage{ - DestAddr: "982b16d9-a945-4762-b684-fd4fd885aa11", - //nolint:gosec - DestPort: uint32(s.destinationServerPort), - SrcAddr: "localhost", - SrcPort: 0, } - _, _, err = client.OpenChannel("direct-tcpip", gossh.Marshal(&msg)) - c.Assert(err, qt.ErrorMatches, ".*user doesn't have permission.*") - client, err = gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{ - //nolint:gosec // this will be removed once we handle hostkeys - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - Auth: []gossh.AuthMethod{ - gossh.PublicKeys(s.privateKey), - }, - User: "bob", - }) - c.Assert(err, qt.IsNil) - defer client.Close() - // send forward message - msg = ssh.ForwardMessage{ - DestAddr: s.allowedModelUUID, - //nolint:gosec - DestPort: uint32(s.destinationServerPort), - SrcAddr: "localhost", - SrcPort: 0, + for _, test := range tests { + c.Run(test.name, func(c *qt.C) { + client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{ + HostKeyCallback: gossh.FixedHostKey(s.hostKey.PublicKey()), + Auth: []gossh.AuthMethod{ + gossh.PublicKeys(s.privateKey), + }, + User: test.user, + }) + c.Assert(err, qt.IsNil) + defer client.Close() + + _, err = client.Dial("tcp", fmt.Sprintf("%s:%d", test.destAddr, s.destinationServerPort)) + c.Assert(err.Error(), qt.Equals, test.errMsg) + }) } - _, _, err = client.OpenChannel("direct-tcpip", gossh.Marshal(&msg)) - c.Assert(err, qt.ErrorMatches, ".*user doesn't have permission.*") } func (s *sshSuite) TestSSHJumpDialFail(c *qt.C) { @@ -252,19 +253,10 @@ func (s *sshSuite) TestSSHFinalDestinationDialFail(c *qt.C) { User: "alice", }) c.Assert(err, qt.IsNil) - - // send forward message - msg := ssh.ForwardMessage{ - DestAddr: "model1", - //nolint:gosec - DestPort: 1, // the test fails because there is no ssh server on this port. - SrcAddr: "localhost", - SrcPort: 0, - } s.testInDestinationServerF = func(fm ssh.ForwardMessage) { c.Check(fm.DestAddr, qt.Equals, "model1") } - _, _, err = client.OpenChannel("direct-tcpip", gossh.Marshal(&msg)) + _, err = client.Dial("tcp", fmt.Sprintf("%s:%d", "model1", 1)) c.Assert(err, qt.ErrorMatches, ".*connect failed.*") } diff --git a/internal/testutils/jimmtest/env.go b/internal/testutils/jimmtest/env.go index 3db83517a..03a5e3a67 100644 --- a/internal/testutils/jimmtest/env.go +++ b/internal/testutils/jimmtest/env.go @@ -407,13 +407,14 @@ func (cc *CloudCredential) DBObject(c Tester, db *db.Database) dbmodel.CloudCred // A Controller represents the definition of a controller in a test // environment. type Controller struct { - Name string `json:"name"` - UUID string `json:"uuid"` - Cloud string `json:"cloud"` - CloudRegion string `json:"region"` - CloudRegions []CloudRegionControllerPriority `json:"cloud-regions"` - AgentVersion string `json:"agent-version"` - Deprecated bool `json:"deprecated,omitempty"` + Name string `json:"name"` + UUID string `json:"uuid"` + Cloud string `json:"cloud"` + CloudRegion string `json:"region"` + CloudRegions []CloudRegionControllerPriority `json:"cloud-regions"` + AgentVersion string `json:"agent-version"` + PublicAddress string `json:"public-address"` + Deprecated bool `json:"deprecated,omitempty"` env *Environment dbo dbmodel.Controller @@ -428,6 +429,7 @@ func (ctl *Controller) DBObject(c Tester, db *db.Database) dbmodel.Controller { ctl.dbo.AgentVersion = ctl.AgentVersion ctl.dbo.CloudName = ctl.Cloud ctl.dbo.CloudRegion = ctl.CloudRegion + ctl.dbo.PublicAddress = ctl.PublicAddress ctl.dbo.CloudRegions = make([]dbmodel.CloudRegionControllerPriority, len(ctl.CloudRegions)) ctl.dbo.Deprecated = ctl.Deprecated for i, cr := range ctl.CloudRegions { diff --git a/internal/testutils/jimmtest/mocks/jimm_ssh_mock.go b/internal/testutils/jimmtest/mocks/jimm_ssh_mock.go index 9acae284f..91b7c35ea 100644 --- a/internal/testutils/jimmtest/mocks/jimm_ssh_mock.go +++ b/internal/testutils/jimmtest/mocks/jimm_ssh_mock.go @@ -6,13 +6,14 @@ import ( "context" "github.com/canonical/jimm/v3/internal/errors" + "github.com/canonical/jimm/v3/internal/jimm/ssh" "github.com/canonical/jimm/v3/internal/openfga" ) // SSHManager is an implementation of the SshManager interface. type SSHManager struct { - PublicKeyHandler_ func(ctx context.Context, claimUser string, key []byte) (*openfga.User, error) - ResolveAddressesFromModelUUID_ func(ctx context.Context, modelUUID string) ([]string, error) + PublicKeyHandler_ func(ctx context.Context, claimUser string, key []byte) (*openfga.User, error) + ControllerInfoFromModelUUID_ func(ctx context.Context, modelUUID string, user *openfga.User) (ssh.ControllerInfo, error) } func (j SSHManager) PublicKeyHandler(ctx context.Context, claimUser string, key []byte) (*openfga.User, error) { @@ -22,9 +23,9 @@ func (j SSHManager) PublicKeyHandler(ctx context.Context, claimUser string, key return j.PublicKeyHandler_(ctx, claimUser, key) } -func (j SSHManager) ResolveAddressesFromModelUUID(ctx context.Context, modelUUID string) ([]string, error) { - if j.ResolveAddressesFromModelUUID_ == nil { - return nil, errors.E(errors.CodeNotImplemented) +func (j SSHManager) ControllerInfoFromModelUUID(ctx context.Context, modelUUID string, user *openfga.User) (ssh.ControllerInfo, error) { + if j.ControllerInfoFromModelUUID_ == nil { + return ssh.ControllerInfo{}, errors.E(errors.CodeNotImplemented) } - return j.ResolveAddressesFromModelUUID_(ctx, modelUUID) + return j.ControllerInfoFromModelUUID_(ctx, modelUUID, user) } diff --git a/internal/testutils/jimmtest/mocks/jwt_service_mock.go b/internal/testutils/jimmtest/mocks/jwt_service_mock.go new file mode 100644 index 000000000..516f09678 --- /dev/null +++ b/internal/testutils/jimmtest/mocks/jwt_service_mock.go @@ -0,0 +1,20 @@ +// Copyright 2025 Canonical. +package mocks + +import ( + "context" + + "github.com/canonical/jimm/v3/internal/errors" + "github.com/canonical/jimm/v3/internal/jimmjwx" +) + +type JWTService struct { + NewJWT_ func(context.Context, jimmjwx.JWTParams) ([]byte, error) +} + +func (j JWTService) NewJWT(ctx context.Context, params jimmjwx.JWTParams) ([]byte, error) { + if j.NewJWT_ == nil { + return nil, errors.E(errors.CodeNotImplemented) + } + return j.NewJWT_(ctx, params) +}