Skip to content

Commit

Permalink
fix: login and registration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Benehiko committed Jul 28, 2023
1 parent 6245457 commit a0d91a6
Show file tree
Hide file tree
Showing 16 changed files with 169 additions and 71 deletions.
4 changes: 4 additions & 0 deletions driver/registry_default_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ func (m *RegistryDefault) PostLoginHooks(ctx context.Context, credentialsType id
}
}

if credentialsType == identity.CredentialsTypeCodeAuth && m.Config().SelfServiceCodeStrategy(ctx).LoginEnabled {
b = append(b, m.HookCodeAddressVerifier())
}

if len(b) == 0 {
// since we don't want merging hooks defined in a specific strategy and global hooks
// global hooks are added only if no strategy specific hooks are defined
Expand Down
11 changes: 11 additions & 0 deletions persistence/sql/persister_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,14 @@ func (p *Persister) DeleteLoginCodesOfFlow(ctx context.Context, flowID uuid.UUID
//#nosec G201 -- TableName is static
return p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE selfservice_login_flow_id = ? AND nid = ?", new(code.LoginCode).TableName(ctx)), flowID, p.NetworkID(ctx)).Exec()
}

func (p *Persister) GetUsedLoginCode(ctx context.Context, flowID uuid.UUID) (*code.LoginCode, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetUsedLoginCode")
defer span.End()

var loginCode code.LoginCode
if err := p.Connection(ctx).RawQuery(fmt.Sprintf("SELECT * FROM %s WHERE selfservice_login_flow_id = ? AND used_at IS NOT NULL AND nid = ?", new(code.LoginCode).TableName(ctx)), flowID, p.NetworkID(ctx)).First(&loginCode); err != nil {
return nil, sqlcon.HandleError(err)
}
return &loginCode, nil
}
8 changes: 7 additions & 1 deletion persistence/sql/persister_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"time"

"github.com/bxcodec/faker/v3/support/slice"
"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
Expand Down Expand Up @@ -127,7 +128,7 @@ func (p *Persister) UseRegistrationCode(ctx context.Context, flowID uuid.UUID, r
}

var registrationCodes []code.RegistrationCode
if err := sqlcon.HandleError(tx.Where("nid = ? AND selfservice_registration_flow_id = ? AND address IN ?", nid, flowID, addresses).All(&registrationCodes)); err != nil {
if err := sqlcon.HandleError(tx.Where("nid = ? AND selfservice_registration_flow_id = ?", nid, flowID).All(&registrationCodes)); err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
// Return no error, as that would roll back the transaction
return nil
Expand Down Expand Up @@ -173,6 +174,11 @@ func (p *Persister) UseRegistrationCode(ctx context.Context, flowID uuid.UUID, r
return nil, code.ErrCodeAlreadyUsed
}

// ensure that the identifiers extracted from the traits are contained in the registration code
if !slice.Contains(addresses, registrationCode.Address) {
return nil, code.ErrCodeNotFound
}

return registrationCode, nil
}

Expand Down
1 change: 0 additions & 1 deletion selfservice/flow/login/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ func (e *HookExecutor) PostLoginHook(
x.SecureRedirectAllowSelfServiceURLs(c.SelfPublicURL(r.Context())),
x.SecureRedirectOverrideDefaultReturnTo(c.SelfServiceFlowLoginReturnTo(r.Context(), a.Active.String())),
)

if err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions selfservice/flow/login/hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func TestLoginExecutor(t *testing.T) {
identity.CredentialsTypeTOTP,
identity.CredentialsTypeWebAuthn,
identity.CredentialsTypeLookup,
identity.CredentialsTypeCodeAuth,
} {
strategy := strategy

Expand Down
32 changes: 31 additions & 1 deletion selfservice/hook/code_address_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ import (

"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/flow/login"
"github.com/ory/kratos/selfservice/flow/registration"
"github.com/ory/kratos/selfservice/flow/verification"
"github.com/ory/kratos/selfservice/strategy/code"
"github.com/ory/kratos/session"
"github.com/ory/kratos/ui/node"
"github.com/ory/kratos/x"
)

Expand All @@ -25,6 +27,7 @@ type (
verification.StrategyProvider
verification.FlowPersistenceProvider
code.RegistrationCodePersistenceProvider
code.LoginCodePersistenceProvider
identity.PrivilegedPoolProvider
x.WriterProvider
}
Expand All @@ -33,12 +36,39 @@ type (
}
)

var _ registration.PostHookPostPersistExecutor = new(Verifier)
var (
_ registration.PostHookPostPersistExecutor = new(CodeAddressVerifier)
_ login.PostHookExecutor = new(CodeAddressVerifier)
)

func NewCodeAddressVerifier(r codeAddressDependencies) *CodeAddressVerifier {
return &CodeAddressVerifier{r: r}
}

func (cv *CodeAddressVerifier) ExecuteLoginPostHook(_ http.ResponseWriter, r *http.Request, _ node.UiNodeGroup, f *login.Flow, s *session.Session) error {
if f.Active != identity.CredentialsTypeCodeAuth {
return nil
}

loginCode, err := cv.r.LoginCodePersister().GetUsedLoginCode(r.Context(), f.GetID())
if err != nil {
return errors.WithStack(err)
}

for idx := range s.Identity.VerifiableAddresses {
va := s.Identity.VerifiableAddresses[idx]
if !va.Verified && loginCode.Address == va.Value {
va.Verified = true
va.Status = identity.VerifiableAddressStatusCompleted
if err := cv.r.PrivilegedIdentityPool().UpdateVerifiableAddress(r.Context(), &va); err != nil {
return errors.WithStack(err)
}
break
}
}
return nil
}

func (cv *CodeAddressVerifier) ExecutePostRegistrationPostPersistHook(w http.ResponseWriter, r *http.Request, a *registration.Flow, s *session.Session) error {
if a.Active != identity.CredentialsTypeCodeAuth {
return nil
Expand Down
8 changes: 6 additions & 2 deletions selfservice/strategy/code/code_sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ func (s *Sender) SendCode(ctx context.Context, f flow.Flow, id *identity.Identit
WithSensitiveField("registration_code", rawCode).
Info("Sending out registration email with code.")

return s.send(ctx, string(address.Via), email.NewRegistrationCodeValid(s.deps, &emailModel))
if err := s.send(ctx, string(address.Via), email.NewRegistrationCodeValid(s.deps, &emailModel)); err != nil {
return errors.WithStack(err)
}

case flow.LoginFlow:
code, err := s.deps.
Expand Down Expand Up @@ -137,7 +139,9 @@ func (s *Sender) SendCode(ctx context.Context, f flow.Flow, id *identity.Identit
WithSensitiveField("login_code", rawCode).
Info("Sending out login email with code.")

return s.send(ctx, string(address.Via), email.NewLoginCodeValid(s.deps, &emailModel))
if err := s.send(ctx, string(address.Via), email.NewLoginCodeValid(s.deps, &emailModel)); err != nil {
return errors.WithStack(err)
}

default:
return errors.WithStack(errors.New("received unknown flow type"))
Expand Down
1 change: 1 addition & 0 deletions selfservice/strategy/code/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,6 @@ type (
CreateLoginCode(context.Context, *CreateLoginCodeParams) (*LoginCode, error)
UseLoginCode(ctx context.Context, flowID uuid.UUID, identityID uuid.UUID, code string) (*LoginCode, error)
DeleteLoginCodesOfFlow(ctx context.Context, flowID uuid.UUID) error
GetUsedLoginCode(ctx context.Context, flowID uuid.UUID) (*LoginCode, error)
}
)
12 changes: 10 additions & 2 deletions selfservice/strategy/code/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"encoding/json"
"net/http"
"strings"

"github.com/gofrs/uuid"
"github.com/pkg/errors"
Expand Down Expand Up @@ -139,17 +140,24 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
// Step 1: Get the identity
i, cred, err := s.getIdentity(ctx, p.Identifier)
if err != nil {
return errors.WithStack(err)
return err
}

// Step 2: Delete any previous login codes for this flow ID
if err := s.deps.LoginCodePersister().DeleteLoginCodesOfFlow(ctx, f.ID); err != nil {
return errors.WithStack(err)
}

var identifier string
for _, id := range cred.Identifiers {
if strings.EqualFold(p.Identifier, id) {
identifier = id
}
}

addresse := []Address{
{
To: p.Identifier,
To: identifier,
Via: identity.CodeAddressType(string(cred.IdentifierAddressType)),
},
}
Expand Down
1 change: 1 addition & 0 deletions selfservice/strategy/code/strategy_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ func TestLoginCodeStrategy(t *testing.T) {
"csrf_token": {csrfToken},
"method": {"code"},
"code": {loginCode},
"identifier": {loginEmail},
}.Encode()))
require.NoError(t, err)

Expand Down
22 changes: 5 additions & 17 deletions selfservice/strategy/code/strategy_registration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,30 +169,20 @@ func TestRegistrationCodeStrategy(t *testing.T) {
require.NoError(t, err)
csrfToken := gjson.GetBytes(body, "ui.nodes.#(attributes.name==csrf_token).attributes.value").String()
require.NotEmptyf(t, csrfToken, "%s", body)
require.Equal(t, email, gjson.GetBytes(body, "ui.nodes.#(attributes.name==traits.email).attributes.value").String())
}

// ory_kratos_continuity cookie is set to keep the state between the initial and the follow-up request
// since we cannot persist the identity until the code has been entered and verified, we keep the state
// within the cookie
var continuityCookie *http.Cookie
for _, c := range resp.Cookies() {
if strings.EqualFold(c.Name, "ory_kratos_continuity") {
continuityCookie = c
break
}
}
require.NotNil(t, continuityCookie)
require.NotEmpty(t, continuityCookie.Value)
require.NoError(t, resp.Body.Close())

return s
}

submitOTP := func(t *testing.T, s *state, otp string, shouldHaveSessionCookie bool) *state {
req, err := http.NewRequestWithContext(ctx, "POST", public.URL+registration.RouteSubmitFlow+"?flow="+s.flowID, strings.NewReader(url.Values{
"csrf_token": {s.csrfToken},
"method": {"code"},
"code": {otp},
"csrf_token": {s.csrfToken},
"method": {"code"},
"code": {otp},
"traits.email": {s.email},
}.Encode()))
require.NoError(t, err)

Expand Down Expand Up @@ -230,7 +220,6 @@ func TestRegistrationCodeStrategy(t *testing.T) {
}

t.Run("case=should be able to register with code identity credentials", func(t *testing.T) {

// 1. Initiate flow
state := createRegistrationFlow(t)

Expand Down Expand Up @@ -263,7 +252,6 @@ func TestRegistrationCodeStrategy(t *testing.T) {
require.NoError(t, err)
require.Contains(t, gjson.GetBytes(body, "ui.messages").String(), "Could not find any login identifiers")
}))

})

t.Run("case=should have verifiable address even if after session hook is disabled", func(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ context("Registration success with code method", () => {
cy.get("button[name=method][value=code]").click()
})

cy.deleteMail({ atLeast: 1 })

cy.visit(login)
cy.get(
"form[data-testid='login-flow-code'] input[name=identifier]",
Expand All @@ -66,6 +68,9 @@ context("Registration success with code method", () => {
)
cy.get("button[name=method][value=code]").click()
})

cy.deleteMail({ atLeast: 1 })

if (app === "express") {
cy.get('a[href*="sessions"').click()
}
Expand Down Expand Up @@ -101,6 +106,8 @@ context("Registration success with code method", () => {
cy.get("button[name=method][value=code]").click()
})

cy.deleteMail({ atLeast: 1 })

if (app === "express") {
cy.get('a[href*="sessions"').click()
}
Expand All @@ -120,16 +127,7 @@ context("Registration success with code method", () => {
{
hook: "session",
},
{
hook: "show_verification_ui",
},
])
cy.setupHooks("login", "after", "code", [
{
hook: "require_verified_address",
},
])
cy.enableVerification()

// Setup complex schema
cy.setIdentitySchema(
Expand Down Expand Up @@ -158,12 +156,18 @@ context("Registration success with code method", () => {

// intentionally use email 1 to verify the account
cy.url().should("contain", "registration")
cy.getRegistrationCodeFromEmail(email).then((code) => {
cy.get(
"form[data-testid='registration-flow-code'] input[name=code]",
).type(code)
cy.get("button[name=method][value=code]").click()
})
cy.getRegistrationCodeFromEmail(email, { expectedCount: 2 }).then(
(code) => {
cy.get(
"form[data-testid='registration-flow-code'] input[name=code]",
).type(code)
cy.get("button[name=method][value=code]").click()
},
)

cy.deleteMail({ atLeast: 2 })

cy.logout()

// Attempt to sign in with email 2 (should fail)
cy.visit(login)
Expand All @@ -182,11 +186,20 @@ context("Registration success with code method", () => {
if (app === "express") {
cy.get('a[href*="sessions"').click()
}

cy.getSession().should((session) => {
console.dir({ session })
const { identity } = session
expect(identity.id).to.not.be.empty
expect(identity.verifiable_addresses).to.have.length(1)
expect(identity.verifiable_addresses[0].status).to.equal("completed")
expect(identity.verifiable_addresses).to.have.length(2)
expect(
identity.verifiable_addresses.filter((v) => v.value === email)[0]
.status,
).to.equal("completed")
expect(
identity.verifiable_addresses.filter((v) => v.value === email2)[0]
.status,
).to.equal("completed")
expect(identity.traits.email).to.equal(email)
})
})
Expand Down
Loading

0 comments on commit a0d91a6

Please sign in to comment.