`)
})
@@ -112,9 +112,10 @@ func Test_ReceiverRegistrationHandler_ServeHTTP(t *testing.T) {
"mywallet://")
receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{})
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.DraftReceiversWalletStatus)
- receiverWallet.StellarAddress = "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444"
- receiverWallet.StellarMemo = ""
- err = receiverWalletModel.UpdateReceiverWallet(ctx, *receiverWallet, dbConnectionPool)
+ err = receiverWalletModel.Update(ctx, receiverWallet.ID, data.ReceiverWalletUpdate{
+ StellarAddress: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444",
+ StellarMemo: "",
+ }, dbConnectionPool)
require.NoError(t, err)
t.Run("returns 200 - Ok (And show the Registration Success page) if the token is in the request context and it's valid and the user was already registered 🎉", func(t *testing.T) {
@@ -168,7 +169,7 @@ func Test_ReceiverRegistrationHandler_ServeHTTP(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
assert.Contains(t, string(respBody), "
`)
+ assert.Contains(t, string(respBody), `
`)
assert.Contains(t, string(respBody), `
`)
assert.Contains(t, string(respBody), `
Your data is processed by MyCustomAid in accordance with their Privacy Policy
`)
})
diff --git a/internal/serve/httphandler/receiver_send_otp_handler.go b/internal/serve/httphandler/receiver_send_otp_handler.go
index b35e0e019..f210662e8 100644
--- a/internal/serve/httphandler/receiver_send_otp_handler.go
+++ b/internal/serve/httphandler/receiver_send_otp_handler.go
@@ -1,6 +1,7 @@
package httphandler
import (
+ "context"
"encoding/json"
"errors"
"fmt"
@@ -19,13 +20,15 @@ import (
"github.com/stellar/stellar-disbursement-platform-backend/internal/utils"
)
+type OTPRegistrationType string
+
// OTPMessageDisclaimer contains disclaimer text that needs to be added as part of the OTP message to remind the
// receiver how sensitive the data is.
const OTPMessageDisclaimer = " If you did not request this code, please ignore. Do not share your code with anyone."
type ReceiverSendOTPHandler struct {
Models *data.Models
- SMSMessengerClient message.MessengerClient
+ MessageDispatcher message.MessageDispatcherInterface
ReCAPTCHAValidator validators.ReCAPTCHAValidator
}
@@ -36,24 +39,50 @@ type ReceiverSendOTPData struct {
type ReceiverSendOTPRequest struct {
PhoneNumber string `json:"phone_number"`
+ Email string `json:"email"`
ReCAPTCHAToken string `json:"recaptcha_token"`
}
+// validateContactInfo validates the contact information provided in the ReceiverSendOTPRequest. It ensures that either
+// the phone number or email is provided, but not both. It also validates the phone number and email format.
+func (r ReceiverSendOTPRequest) validateContactInfo() validators.Validator {
+ v := *validators.NewValidator()
+ r.Email = utils.TrimAndLower(r.Email)
+ r.PhoneNumber = utils.TrimAndLower(r.PhoneNumber)
+
+ switch {
+ case r.PhoneNumber == "" && r.Email == "":
+ v.Check(false, "phone_number", "phone_number or email is required")
+ v.Check(false, "email", "phone_number or email is required")
+ case r.PhoneNumber != "" && r.Email != "":
+ v.Check(false, "phone_number", "phone_number and email cannot be both provided")
+ v.Check(false, "email", "phone_number and email cannot be both provided")
+ case r.PhoneNumber != "":
+ v.CheckError(utils.ValidatePhoneNumber(r.PhoneNumber), "phone_number", "")
+ case r.Email != "":
+ v.CheckError(utils.ValidateEmail(r.Email), "email", "")
+ }
+
+ return v
+}
+
type ReceiverSendOTPResponseBody struct {
- Message string `json:"message"`
- VerificationField data.VerificationField `json:"verification_field"`
+ Message string `json:"message"`
+ VerificationField data.VerificationType `json:"verification_field"`
}
func (h ReceiverSendOTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
+ // Parse request body
receiverSendOTPRequest := ReceiverSendOTPRequest{}
-
err := json.NewDecoder(r.Body).Decode(&receiverSendOTPRequest)
if err != nil {
httperror.BadRequest("invalid request body", err, nil).Render(w)
return
}
+ receiverSendOTPRequest.PhoneNumber = utils.TrimAndLower(receiverSendOTPRequest.PhoneNumber)
+ receiverSendOTPRequest.Email = utils.TrimAndLower(receiverSendOTPRequest.Email)
// validating reCAPTCHA Token
isValid, err := h.ReCAPTCHAValidator.IsTokenValid(ctx, receiverSendOTPRequest.ReCAPTCHAToken)
@@ -61,26 +90,13 @@ func (h ReceiverSendOTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
httperror.InternalError(ctx, "Cannot validate reCAPTCHA token", err, nil).Render(w)
return
}
-
if !isValid {
log.Ctx(ctx).Errorf("reCAPTCHA token is invalid")
httperror.BadRequest("reCAPTCHA token is invalid", nil, nil).Render(w)
return
}
- truncatedPhoneNumber := utils.TruncateString(receiverSendOTPRequest.PhoneNumber, 3)
- if phoneValidateErr := utils.ValidatePhoneNumber(receiverSendOTPRequest.PhoneNumber); phoneValidateErr != nil {
- extras := map[string]interface{}{"phone_number": "phone_number is required"}
- if !errors.Is(phoneValidateErr, utils.ErrEmptyPhoneNumber) {
- phoneValidateErr = fmt.Errorf("validating phone number %s: %w", truncatedPhoneNumber, phoneValidateErr)
- log.Ctx(ctx).Error(phoneValidateErr)
- extras["phone_number"] = "invalid phone number provided"
- }
- httperror.BadRequest("request invalid", phoneValidateErr, extras).Render(w)
- return
- }
-
- // Get clains from SEP24 JWT
+ // Validate SEP-24 JWT claims
sep24Claims := anchorplatform.GetSEP24Claims(ctx)
if sep24Claims == nil {
err = fmt.Errorf("no SEP-24 claims found in the request context")
@@ -88,7 +104,6 @@ func (h ReceiverSendOTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
httperror.Unauthorized("", err, nil).Render(w)
return
}
-
err = sep24Claims.Valid()
if err != nil {
err = fmt.Errorf("SEP-24 claims are invalid: %w", err)
@@ -97,76 +112,136 @@ func (h ReceiverSendOTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
return
}
- verificationField := data.VerificationFieldDateOfBirth
- receiverVerification, err := h.Models.ReceiverVerification.GetLatestByPhoneNumber(ctx, receiverSendOTPRequest.PhoneNumber)
- if err != nil {
- err = fmt.Errorf("cannot find latest receiver verification for phone number %s: %w", truncatedPhoneNumber, err)
- log.Ctx(ctx).Error(err)
+ // Ensure XOR(PhoneNumber, Email)
+ if v := receiverSendOTPRequest.validateContactInfo(); v.HasErrors() {
+ httperror.BadRequest("", nil, v.Errors).Render(w)
+ return
+ }
+
+ // Determine the contact type and handle accordingly
+ var contactType data.ReceiverContactType
+ var contactInfo string
+ if receiverSendOTPRequest.PhoneNumber != "" {
+ contactType, contactInfo = data.ReceiverContactTypeSMS, receiverSendOTPRequest.PhoneNumber
+ } else if receiverSendOTPRequest.Email != "" {
+ contactType, contactInfo = data.ReceiverContactTypeEmail, receiverSendOTPRequest.Email
} else {
- verificationField = receiverVerification.VerificationField
+ httperror.InternalError(ctx, "Unexpected contact info", nil, nil).Render(w)
+ return
+ }
+ verificationField, httpErr := h.handleOTPForReceiver(ctx, contactType, contactInfo, sep24Claims.ClientDomainClaim)
+ if httpErr != nil {
+ httpErr.Render(w)
+ return
+ }
+
+ response := newReceiverSendOTPResponseBody(contactType, verificationField)
+ httpjson.RenderStatus(w, http.StatusOK, response, httpjson.JSON)
+}
+
+// newReceiverSendOTPResponseBody creates a new ReceiverSendOTPResponseBody based on the OTP registration type and verification field.
+func newReceiverSendOTPResponseBody(contactType data.ReceiverContactType, verificationField data.VerificationType) ReceiverSendOTPResponseBody {
+ resp := ReceiverSendOTPResponseBody{VerificationField: verificationField}
+
+ switch contactType {
+ case data.ReceiverContactTypeSMS:
+ resp.Message = "if your phone number is registered, you'll receive an OTP"
+ case data.ReceiverContactTypeEmail:
+ resp.Message = "if your email is registered, you'll receive an OTP"
+ }
+
+ return resp
+}
+
+// handleOTPReceiver handles the OTP generation and sending for a receiver with the provided contactType and contactInfo.
+func (h ReceiverSendOTPHandler) handleOTPForReceiver(
+ ctx context.Context,
+ contactType data.ReceiverContactType,
+ contactInfo string,
+ sep24ClientDomain string,
+) (data.VerificationType, *httperror.HTTPError) {
+ var err error
+ placeholderVerificationField := data.VerificationTypeDateOfBirth
+ truncatedContactInfo := utils.TruncateString(contactInfo, 3)
+ contactTypeStr := utils.Humanize(string(contactType))
+
+ // get receiverVerification by that value of contactInfo
+ receiverVerification, err := h.Models.ReceiverVerification.GetLatestByContactInfo(ctx, contactInfo)
+ if err != nil {
+ log.Ctx(ctx).Warnf("Could not find ANY receiver verification for %s %s: %v", contactTypeStr, truncatedContactInfo, err)
+ return placeholderVerificationField, nil
}
// Generate a new 6 digits OTP
newOTP, err := utils.RandomString(6, utils.NumberBytes)
if err != nil {
- httperror.InternalError(ctx, "Cannot generate OTP for receiver wallet", err, nil).Render(w)
- return
+ return placeholderVerificationField, httperror.InternalError(ctx, "Cannot generate OTP for receiver wallet", err, nil)
+ }
+
+ // Update OTP for receiver wallet
+ numberOfUpdatedRows, err := h.Models.ReceiverWallet.UpdateOTPByReceiverContactInfoAndWalletDomain(ctx, contactInfo, sep24ClientDomain, newOTP)
+ if err != nil && !errors.Is(err, data.ErrRecordNotFound) {
+ return placeholderVerificationField, httperror.InternalError(ctx, "Cannot update OTP for receiver wallet", err, nil)
+ }
+ if numberOfUpdatedRows < 1 {
+ log.Ctx(ctx).Warnf("Could not find a match between %s (%s) and client domain (%s)", contactTypeStr, truncatedContactInfo, sep24ClientDomain)
+ return placeholderVerificationField, nil
}
+ // Send OTP message
+ err = h.sendOTP(ctx, contactType, contactInfo, newOTP)
+ if err != nil {
+ err = fmt.Errorf("sending OTP message: %w", err)
+ return placeholderVerificationField, httperror.InternalError(ctx, "Failed to send OTP message, reason: "+err.Error(), err, nil)
+ }
+
+ return receiverVerification.VerificationField, nil
+}
+
+// sendOTP sends an OTP through the provided contact type to the provided contact information.
+func (h ReceiverSendOTPHandler) sendOTP(ctx context.Context, contactType data.ReceiverContactType, contactInfo, otp string) error {
organization, err := h.Models.Organizations.Get(ctx)
if err != nil {
- httperror.InternalError(ctx, "Cannot get organization", err, nil).Render(w)
- return
+ return fmt.Errorf("cannot get organization: %w", err)
}
- numberOfUpdatedRows, err := h.Models.ReceiverWallet.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, receiverSendOTPRequest.PhoneNumber, sep24Claims.ClientDomainClaim, newOTP)
+ otpMessageTemplate := organization.OTPMessageTemplate + OTPMessageDisclaimer
+ if !strings.Contains(organization.OTPMessageTemplate, "{{.OTP}}") {
+ // Adding the OTP code to the template
+ otpMessageTemplate = fmt.Sprintf(`{{.OTP}} %s`, strings.TrimSpace(otpMessageTemplate))
+ }
+
+ sendOTPMessageTpl, err := template.New("").Parse(otpMessageTemplate)
if err != nil {
- httperror.InternalError(ctx, "Cannot update OTP for receiver wallet", err, nil).Render(w)
- return
+ return fmt.Errorf("cannot parse OTP template: %w", err)
}
- if numberOfUpdatedRows < 1 {
- log.Ctx(ctx).Warnf("updated no rows in ReceiverSendOTPHandler, please verify if the provided phone number (%s) and client_domain (%s) are both valid", truncatedPhoneNumber, sep24Claims.ClientDomainClaim)
- } else {
- sendOTPData := ReceiverSendOTPData{
- OTP: newOTP,
- OrganizationName: organization.Name,
- }
-
- otpMessageTemplate := organization.OTPMessageTemplate + OTPMessageDisclaimer
- if !strings.Contains(organization.OTPMessageTemplate, "{{.OTP}}") {
- // Adding the OTP code to the template
- otpMessageTemplate = fmt.Sprintf(`{{.OTP}} %s`, strings.TrimSpace(otpMessageTemplate))
- }
-
- sendOTPMessageTpl, err := template.New("").Parse(otpMessageTemplate)
- if err != nil {
- httperror.InternalError(ctx, "Cannot parse OTP template", err, nil).Render(w)
- return
- }
-
- builder := new(strings.Builder)
- if err = sendOTPMessageTpl.Execute(builder, sendOTPData); err != nil {
- httperror.InternalError(ctx, "Cannot execute OTP template", err, nil).Render(w)
- return
- }
-
- smsMessage := message.Message{
- ToPhoneNumber: receiverSendOTPRequest.PhoneNumber,
- Message: builder.String(),
- }
-
- log.Ctx(ctx).Infof("sending OTP message to phone number: %s", truncatedPhoneNumber)
- err = h.SMSMessengerClient.SendMessage(smsMessage)
- if err != nil {
- httperror.InternalError(ctx, "Cannot send OTP message", err, nil).Render(w)
- return
- }
- }
-
- response := ReceiverSendOTPResponseBody{
- Message: "if your phone number is registered, you'll receive an OTP",
- VerificationField: verificationField,
+ sendOTPData := ReceiverSendOTPData{
+ OTP: otp,
+ OrganizationName: organization.Name,
}
- httpjson.RenderStatus(w, http.StatusOK, response, httpjson.JSON)
+
+ builder := new(strings.Builder)
+ if err = sendOTPMessageTpl.Execute(builder, sendOTPData); err != nil {
+ return fmt.Errorf("cannot execute OTP template: %w", err)
+ }
+
+ msg := message.Message{Body: builder.String()}
+ switch contactType {
+ case data.ReceiverContactTypeSMS:
+ msg.ToPhoneNumber = contactInfo
+ case data.ReceiverContactTypeEmail:
+ msg.ToEmail = contactInfo
+ msg.Title = "Your One-Time Password: " + otp
+ }
+
+ truncatedContactInfo := utils.TruncateString(contactInfo, 3)
+ contactTypeStr := utils.Humanize(string(contactType))
+ log.Ctx(ctx).Infof("sending OTP message to %s %s...", contactTypeStr, truncatedContactInfo)
+ _, err = h.MessageDispatcher.SendMessage(ctx, msg, organization.MessageChannelPriority)
+ if err != nil {
+ return fmt.Errorf("cannot send OTP message through %s to %s: %w", contactTypeStr, truncatedContactInfo, err)
+ }
+
+ return nil
}
diff --git a/internal/serve/httphandler/receiver_send_otp_handler_test.go b/internal/serve/httphandler/receiver_send_otp_handler_test.go
index 5ef742ed5..dc39ca696 100644
--- a/internal/serve/httphandler/receiver_send_otp_handler_test.go
+++ b/internal/serve/httphandler/receiver_send_otp_handler_test.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
+ "fmt"
"io"
"net/http"
"net/http/httptest"
@@ -14,6 +15,8 @@ import (
"github.com/go-chi/chi/v5"
"github.com/golang-jwt/jwt/v4"
+ "github.com/sirupsen/logrus"
+ "github.com/stellar/go/support/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
@@ -23,359 +26,694 @@ import (
"github.com/stellar/stellar-disbursement-platform-backend/internal/anchorplatform"
"github.com/stellar/stellar-disbursement-platform-backend/internal/data"
"github.com/stellar/stellar-disbursement-platform-backend/internal/message"
+ "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror"
"github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators"
+ "github.com/stellar/stellar-disbursement-platform-backend/internal/utils"
)
-func Test_ReceiverSendOTPHandler_ServeHTTP(t *testing.T) {
- r := chi.NewRouter()
+func Test_ReceiverSendOTPRequest_validateContactInfo(t *testing.T) {
+ testCases := []struct {
+ name string
+ receiverSendOTPRequest ReceiverSendOTPRequest
+ wantValidationErrors map[string]interface{}
+ }{
+ {
+ name: "🔴 phone number and email both empty",
+ receiverSendOTPRequest: ReceiverSendOTPRequest{
+ PhoneNumber: "",
+ Email: "",
+ },
+ wantValidationErrors: map[string]interface{}{
+ "phone_number": "phone_number or email is required",
+ "email": "phone_number or email is required",
+ },
+ },
+ {
+ name: "🔴 phone number and email both provided",
+ receiverSendOTPRequest: ReceiverSendOTPRequest{
+ PhoneNumber: "+141555550000",
+ Email: "foobar@test.com",
+ },
+ wantValidationErrors: map[string]interface{}{
+ "phone_number": "phone_number and email cannot be both provided",
+ "email": "phone_number and email cannot be both provided",
+ },
+ },
+ {
+ name: "🔴 phone number is invalid",
+ receiverSendOTPRequest: ReceiverSendOTPRequest{
+ PhoneNumber: "invalid",
+ },
+ wantValidationErrors: map[string]interface{}{
+ "phone_number": "the provided phone number is not a valid E.164 number",
+ },
+ },
+ {
+ name: "🔴 email is invalid",
+ receiverSendOTPRequest: ReceiverSendOTPRequest{
+ Email: "invalid",
+ },
+ wantValidationErrors: map[string]interface{}{
+ "email": "the provided email is not valid",
+ },
+ },
+ {
+ name: "🟢 phone number is valid",
+ receiverSendOTPRequest: ReceiverSendOTPRequest{
+ PhoneNumber: "+14155550000",
+ },
+ wantValidationErrors: nil,
+ },
+ {
+ name: "🟢 email is valid",
+ receiverSendOTPRequest: ReceiverSendOTPRequest{
+ Email: "foobar@test.com",
+ },
+ wantValidationErrors: nil,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v := tc.receiverSendOTPRequest.validateContactInfo()
+ if len(tc.wantValidationErrors) == 0 {
+ assert.Len(t, v.Errors, 0)
+ } else {
+ assert.Equal(t, tc.wantValidationErrors, v.Errors)
+ }
+ })
+ }
+}
+
+func Test_ReceiverSendOTPHandler_ServeHTTP_validation(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()
-
dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
require.NoError(t, err)
defer dbConnectionPool.Close()
+ ctx := context.Background()
models, err := data.NewModels(dbConnectionPool)
require.NoError(t, err)
- ctx := context.Background()
-
- phoneNumber := "+380443973607"
- receiver1 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{PhoneNumber: phoneNumber})
- receiver2 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{})
- wallet1 := data.CreateWalletFixture(t, ctx, dbConnectionPool, "testWallet", "https://home.page", "home.page", "wallet123://")
- data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
- ReceiverID: receiver1.ID,
- VerificationField: data.VerificationFieldDateOfBirth,
- })
-
- _ = data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet1.ID, data.RegisteredReceiversWalletStatus)
- _ = data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet1.ID, data.RegisteredReceiversWalletStatus)
-
- mockMessenger := message.MessengerClientMock{}
- reCAPTCHAValidator := &validators.ReCAPTCHAValidatorMock{}
-
- r.Post("/wallet-registration/otp", ReceiverSendOTPHandler{
- Models: models,
- SMSMessengerClient: &mockMessenger,
- ReCAPTCHAValidator: reCAPTCHAValidator,
- }.ServeHTTP)
-
- requestSendOTP := ReceiverSendOTPRequest{
- PhoneNumber: receiver1.PhoneNumber,
- ReCAPTCHAToken: "XyZ",
+ validClaims := &anchorplatform.SEP24JWTClaims{
+ ClientDomainClaim: "no-op-domain.test.com",
+ RegisteredClaims: jwt.RegisteredClaims{
+ ID: "test-transaction-id",
+ Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444",
+ ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
+ },
}
- reqBody, err := json.Marshal(requestSendOTP)
- require.NoError(t, err)
-
- t.Run("returns 401 - Unauthorized if the token is not in the request context", func(t *testing.T) {
- reCAPTCHAValidator.
- On("IsTokenValid", mock.Anything, "XyZ").
- Return(true, nil).
- Once()
- req, err := http.NewRequest(http.MethodPost, "/wallet-registration/otp", strings.NewReader(string(reqBody)))
- require.NoError(t, err)
- rr := httptest.NewRecorder()
- r.ServeHTTP(rr, req)
-
- resp := rr.Result()
- respBody, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
-
- assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
- assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody))
- })
-
- t.Run("returns 401 - Unauthorized if the token is in the request context but it's not valid", func(t *testing.T) {
- reCAPTCHAValidator.
- On("IsTokenValid", mock.Anything, "XyZ").
- Return(true, nil).
- Once()
- req, err := http.NewRequest(http.MethodPost, "/wallet-registration/otp", strings.NewReader(string(reqBody)))
- require.NoError(t, err)
-
- rr := httptest.NewRecorder()
- invalidClaims := &anchorplatform.SEP24JWTClaims{}
- req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, invalidClaims))
- r.ServeHTTP(rr, req)
-
- resp := rr.Result()
- respBody, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
-
- assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
- assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody))
- })
-
- t.Run("returns 400 - BadRequest with a wrong request body", func(t *testing.T) {
- reCAPTCHAValidator.
- On("IsTokenValid", mock.Anything, "XyZ").
- Return(true, nil).
- Twice()
- invalidRequest := `{"recaptcha_token": "XyZ"}`
-
- req, err := http.NewRequest(http.MethodPost, "/wallet-registration/otp", strings.NewReader(invalidRequest))
- require.NoError(t, err)
-
- rr := httptest.NewRecorder()
- invalidClaims := &anchorplatform.SEP24JWTClaims{}
- req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, invalidClaims))
- r.ServeHTTP(rr, req)
-
- resp := rr.Result()
-
- respBody, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
- assert.JSONEq(t, `{"error":"request invalid","extras":{"phone_number":"phone_number is required"}}`, string(respBody))
-
- req, err = http.NewRequest(http.MethodPost, "/wallet-registration/otp", strings.NewReader(`{"phone_number": "+55555555555", "recaptcha_token": "XyZ"}`))
- require.NoError(t, err)
-
- rr = httptest.NewRecorder()
- invalidClaims = &anchorplatform.SEP24JWTClaims{}
- req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, invalidClaims))
- r.ServeHTTP(rr, req)
-
- resp = rr.Result()
-
- respBody, err = io.ReadAll(resp.Body)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
- assert.JSONEq(t, `{"error": "request invalid", "extras": {"phone_number": "invalid phone number provided"}}`, string(respBody))
- })
-
- t.Run("returns 200 - Ok if the token is in the request context and body is valid", func(t *testing.T) {
- reCAPTCHAValidator.
- On("IsTokenValid", mock.Anything, "XyZ").
- Return(true, nil).
- Once()
- req, err := http.NewRequest(http.MethodPost, "/wallet-registration/otp", strings.NewReader(string(reqBody)))
- require.NoError(t, err)
-
- validClaims := &anchorplatform.SEP24JWTClaims{
- ClientDomainClaim: wallet1.SEP10ClientDomain,
- RegisteredClaims: jwt.RegisteredClaims{
- ID: "test-transaction-id",
- Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444",
- ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
+ ctxWithValidSEP24Claims := context.WithValue(ctx, anchorplatform.SEP24ClaimsContextKey, validClaims)
+ invalidClaims := &anchorplatform.SEP24JWTClaims{}
+ ctxWithInvalidSEP24Claims := context.WithValue(ctx, anchorplatform.SEP24ClaimsContextKey, invalidClaims)
+
+ const reCAPTCHAToken = "XyZ"
+
+ testCases := []struct {
+ name string
+ context context.Context
+ receiverSendOTPRequest ReceiverSendOTPRequest
+ prepareMocksFn func(t *testing.T, mockReCAPTCHAValidator *validators.ReCAPTCHAValidatorMock, mockMessageDispatcher *message.MockMessageDispatcher)
+ wantStatusCode int
+ wantBody string
+ }{
+ {
+ name: "(500 - InternalServerError) if the reCAPTCHA validation returns an error",
+ context: ctx,
+ receiverSendOTPRequest: ReceiverSendOTPRequest{ReCAPTCHAToken: "invalid-recaptcha-token"},
+ prepareMocksFn: func(t *testing.T, mockReCAPTCHAValidator *validators.ReCAPTCHAValidatorMock, _ *message.MockMessageDispatcher) {
+ mockReCAPTCHAValidator.
+ On("IsTokenValid", mock.Anything, "invalid-recaptcha-token").
+ Return(false, errors.New("invalid recaptcha")).
+ Once()
},
- }
- req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims))
-
- mockMessenger.On("SendMessage", mock.AnythingOfType("message.Message")).
- Return(nil).
- Once().
- Run(func(args mock.Arguments) {
- msg := args.Get(0).(message.Message)
- assert.Contains(t, msg.Message, "is your MyCustomAid phone verification code.")
- assert.Regexp(t, regexp.MustCompile(`^\d{6}\s.+$`), msg.Message)
- })
-
- rr := httptest.NewRecorder()
- r.ServeHTTP(rr, req)
-
- resp := rr.Result()
- respBody, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
-
- assert.Equal(t, http.StatusOK, resp.StatusCode)
- assert.Contains(t, resp.Header.Get("Content-Type"), "/json; charset=utf-8")
- assert.JSONEq(t, string(respBody), `{"message":"if your phone number is registered, you'll receive an OTP", "verification_field":"DATE_OF_BIRTH"}`)
- })
-
- t.Run("returns 200 - parses a custom OTP message template successfully", func(t *testing.T) {
- reCAPTCHAValidator.
- On("IsTokenValid", mock.Anything, "XyZ").
- Return(true, nil).
- Once()
- req, err := http.NewRequest(http.MethodPost, "/wallet-registration/otp", strings.NewReader(string(reqBody)))
- require.NoError(t, err)
-
- validClaims := &anchorplatform.SEP24JWTClaims{
- ClientDomainClaim: wallet1.SEP10ClientDomain,
- RegisteredClaims: jwt.RegisteredClaims{
- ID: "test-transaction-id",
- Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444",
- ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
+ wantStatusCode: http.StatusInternalServerError,
+ wantBody: `{"error":"Cannot validate reCAPTCHA token"}`,
+ },
+ {
+ name: "(400 - BadRequest) if the reCAPTCHA token is invalid",
+ context: ctx,
+ receiverSendOTPRequest: ReceiverSendOTPRequest{ReCAPTCHAToken: reCAPTCHAToken},
+ prepareMocksFn: func(t *testing.T, mockReCAPTCHAValidator *validators.ReCAPTCHAValidatorMock, _ *message.MockMessageDispatcher) {
+ mockReCAPTCHAValidator.
+ On("IsTokenValid", mock.Anything, reCAPTCHAToken).
+ Return(false, nil).
+ Once()
},
- }
- req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims))
-
- // Set a custom message for the OTP message
- customOTPMessage := "Here's your code to complete your registration. MyOrg 👋"
- err = models.Organizations.Update(ctx, &data.OrganizationUpdate{OTPMessageTemplate: &customOTPMessage})
- require.NoError(t, err)
-
- mockMessenger.On("SendMessage", mock.AnythingOfType("message.Message")).
- Return(nil).
- Once().
- Run(func(args mock.Arguments) {
- msg := args.Get(0).(message.Message)
- assert.Contains(t, msg.Message, customOTPMessage)
- assert.Regexp(t, regexp.MustCompile(`^\d{6}\s.+$`), msg.Message)
- })
-
- rr := httptest.NewRecorder()
- r.ServeHTTP(rr, req)
-
- resp := rr.Result()
- respBody, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
-
- assert.Equal(t, http.StatusOK, resp.StatusCode)
- assert.Contains(t, resp.Header.Get("Content-Type"), "/json; charset=utf-8")
- assert.JSONEq(t, string(respBody), `{"message":"if your phone number is registered, you'll receive an OTP", "verification_field":"DATE_OF_BIRTH"}`)
- })
-
- t.Run("returns 500 - InternalServerError when something goes wrong when sending the SMS", func(t *testing.T) {
- reCAPTCHAValidator.
- On("IsTokenValid", mock.Anything, "XyZ").
- Return(true, nil).
- Once()
- req, err := http.NewRequest(http.MethodPost, "/wallet-registration/otp", strings.NewReader(string(reqBody)))
- require.NoError(t, err)
-
- validClaims := &anchorplatform.SEP24JWTClaims{
- ClientDomainClaim: wallet1.SEP10ClientDomain,
- RegisteredClaims: jwt.RegisteredClaims{
- ID: "test-transaction-id",
- Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444",
- ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
+ wantStatusCode: http.StatusBadRequest,
+ wantBody: `{"error":"reCAPTCHA token is invalid"}`,
+ },
+ {
+ name: "(401 - Unauthorized) if the SEP-24 claims are not in the request context",
+ context: ctx,
+ receiverSendOTPRequest: ReceiverSendOTPRequest{ReCAPTCHAToken: reCAPTCHAToken},
+ prepareMocksFn: func(t *testing.T, mockReCAPTCHAValidator *validators.ReCAPTCHAValidatorMock, _ *message.MockMessageDispatcher) {
+ mockReCAPTCHAValidator.
+ On("IsTokenValid", mock.Anything, reCAPTCHAToken).
+ Return(true, nil).
+ Once()
},
- }
- req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims))
-
- mockMessenger.On("SendMessage", mock.AnythingOfType("message.Message")).
- Return(errors.New("error sending message")).
- Once()
-
- rr := httptest.NewRecorder()
- r.ServeHTTP(rr, req)
-
- resp := rr.Result()
- respBody, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
-
- assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
- assert.Contains(t, resp.Header.Get("Content-Type"), "/json; charset=utf-8")
- assert.JSONEq(t, string(respBody), `{"error":"Cannot send OTP message"}`)
- })
-
- t.Run("returns 500 - InternalServerError when unable to validate recaptcha", func(t *testing.T) {
- reCAPTCHAValidator.
- On("IsTokenValid", mock.Anything, "XyZ").
- Return(false, errors.New("error requesting verify reCAPTCHA token")).
- Once()
-
- req, err := http.NewRequest(http.MethodPost, "/wallet-registration/otp", strings.NewReader(string(reqBody)))
- require.NoError(t, err)
-
- validClaims := &anchorplatform.SEP24JWTClaims{
- ClientDomainClaim: wallet1.SEP10ClientDomain,
- RegisteredClaims: jwt.RegisteredClaims{
- ID: "test-transaction-id",
- Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444",
- ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
+ wantStatusCode: http.StatusUnauthorized,
+ wantBody: `{"error":"Not authorized."}`,
+ },
+ {
+ name: "(401 - Unauthorized) if the SEP-24 claims are invalid",
+ context: ctxWithInvalidSEP24Claims,
+ receiverSendOTPRequest: ReceiverSendOTPRequest{ReCAPTCHAToken: reCAPTCHAToken},
+ prepareMocksFn: func(t *testing.T, mockReCAPTCHAValidator *validators.ReCAPTCHAValidatorMock, _ *message.MockMessageDispatcher) {
+ mockReCAPTCHAValidator.
+ On("IsTokenValid", mock.Anything, reCAPTCHAToken).
+ Return(true, nil).
+ Once()
},
- }
- req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims))
+ wantStatusCode: http.StatusUnauthorized,
+ wantBody: `{"error":"Not authorized."}`,
+ },
+ {
+ name: "(400 - BadRequest) if the request body is invalid",
+ context: ctxWithValidSEP24Claims,
+ receiverSendOTPRequest: ReceiverSendOTPRequest{ReCAPTCHAToken: reCAPTCHAToken},
+ prepareMocksFn: func(t *testing.T, mockReCAPTCHAValidator *validators.ReCAPTCHAValidatorMock, _ *message.MockMessageDispatcher) {
+ mockReCAPTCHAValidator.
+ On("IsTokenValid", mock.Anything, reCAPTCHAToken).
+ Return(true, nil).
+ Once()
+ },
+ wantStatusCode: http.StatusBadRequest,
+ wantBody: `{
+ "error": "The request was invalid in some way.",
+ "extras": {
+ "phone_number":"phone_number or email is required",
+ "email":"phone_number or email is required"
+ }
+ }`,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ mockReCAPTCHAValidator := validators.NewReCAPTCHAValidatorMock(t)
+ mockMessageDispatcher := message.NewMockMessageDispatcher(t)
+
+ tc.prepareMocksFn(t, mockReCAPTCHAValidator, mockMessageDispatcher)
- w := httptest.NewRecorder()
- r.ServeHTTP(w, req)
+ r := chi.NewRouter()
+ r.Post("/wallet-registration/otp", ReceiverSendOTPHandler{
+ Models: models,
+ MessageDispatcher: mockMessageDispatcher,
+ ReCAPTCHAValidator: mockReCAPTCHAValidator,
+ }.ServeHTTP)
- resp := w.Result()
- respBody, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
+ reqBody, err := json.Marshal(tc.receiverSendOTPRequest)
+ require.NoError(t, err)
+ req, err := http.NewRequestWithContext(tc.context, http.MethodPost, "/wallet-registration/otp", strings.NewReader(string(reqBody)))
+ require.NoError(t, err)
+ rr := httptest.NewRecorder()
- wantsBody := `
- {
- "error": "Cannot validate reCAPTCHA token"
+ r.ServeHTTP(rr, req)
+
+ resp := rr.Result()
+ defer resp.Body.Close()
+ respBody, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ assert.Equal(t, tc.wantStatusCode, resp.StatusCode)
+ assert.JSONEq(t, tc.wantBody, string(respBody))
+ })
+ }
+}
+
+func Test_ReceiverSendOTPHandler_ServeHTTP_otpHandlerIsCalled(t *testing.T) {
+ dbt := dbtest.Open(t)
+ defer dbt.Close()
+ dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
+ require.NoError(t, err)
+ defer dbConnectionPool.Close()
+
+ ctx := context.Background()
+ models, err := data.NewModels(dbConnectionPool)
+ const phoneNumber = "+14155550000"
+ const email = "foobar@test.com"
+ require.NoError(t, err)
+ wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "testWallet", "https://correct.test", "correct.test", "wallet123://")
+
+ validClaims := &anchorplatform.SEP24JWTClaims{
+ ClientDomainClaim: wallet.SEP10ClientDomain,
+ RegisteredClaims: jwt.RegisteredClaims{
+ ID: "test-transaction-id",
+ Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444",
+ ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
+ },
+ }
+ ctxWithValidSEP24Claims := context.WithValue(ctx, anchorplatform.SEP24ClaimsContextKey, validClaims)
+
+ const reCAPTCHAToken = "XyZ"
+
+ type testCase struct {
+ name string
+ receiverSendOTPRequest ReceiverSendOTPRequest
+ verificationField data.VerificationType
+ contactType data.ReceiverContactType
+ prepareMocksFn func(t *testing.T, mockReCAPTCHAValidator *validators.ReCAPTCHAValidatorMock, mockMessageDispatcher *message.MockMessageDispatcher)
+ shouldCreateObjects bool
+ assertLogsFn func(t *testing.T, contactType data.ReceiverContactType, r data.Receiver, entries []logrus.Entry)
+ wantStatusCode int
+ wantBody string
+ }
+ testCases := []testCase{}
+
+ for _, contactType := range data.GetAllReceiverContactTypes() {
+ for _, verificationField := range data.GetAllVerificationTypes() {
+ receiverSendOTPRequest := ReceiverSendOTPRequest{ReCAPTCHAToken: reCAPTCHAToken}
+ var contactInfo string
+ var messengerType message.MessengerType
+ switch contactType {
+ case data.ReceiverContactTypeSMS:
+ receiverSendOTPRequest.PhoneNumber = phoneNumber
+ contactInfo = phoneNumber
+ messengerType = message.MessengerTypeTwilioSMS
+ case data.ReceiverContactTypeEmail:
+ receiverSendOTPRequest.Email = email
+ contactInfo = email
+ messengerType = message.MessengerTypeAWSEmail
}
- `
- assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
- assert.JSONEq(t, wantsBody, string(respBody))
- })
-
- t.Run("returns 200 (DoB) - InternalServerError if phone number is not associated with receiver verification", func(t *testing.T) {
- requestSendOTP := ReceiverSendOTPRequest{
- PhoneNumber: "+14152223333",
- ReCAPTCHAToken: "XyZ",
+ truncatedContactInfo := utils.TruncateString(contactInfo, 3)
+
+ testCases = append(testCases, []testCase{
+ {
+ name: fmt.Sprintf("%s/%s/🔴 (500-InternalServerError) when the SMS dispatcher fails", contactType, verificationField),
+ receiverSendOTPRequest: receiverSendOTPRequest,
+ verificationField: verificationField,
+ contactType: contactType,
+ shouldCreateObjects: true,
+ prepareMocksFn: func(t *testing.T, mockReCAPTCHAValidator *validators.ReCAPTCHAValidatorMock, mockMessageDispatcher *message.MockMessageDispatcher) {
+ mockReCAPTCHAValidator.
+ On("IsTokenValid", mock.Anything, reCAPTCHAToken).
+ Return(true, nil).
+ Once()
+ mockMessageDispatcher.
+ On("SendMessage",
+ mock.Anything,
+ mock.AnythingOfType("message.Message"),
+ []message.MessageChannel{message.MessageChannelSMS, message.MessageChannelEmail}).
+ Return(messengerType, errors.New("failed calling message dispatcher")).
+ Once().
+ Run(func(args mock.Arguments) {
+ msg := args.Get(1).(message.Message)
+ assert.Contains(t, msg.Body, "is your MyCustomAid verification code.")
+ assert.Regexp(t, regexp.MustCompile(`^\d{6}\s.+$`), msg.Body)
+ })
+ },
+ assertLogsFn: func(t *testing.T, contactType data.ReceiverContactType, r data.Receiver, entries []logrus.Entry) {
+ contactTypeStr := utils.Humanize(string(contactType))
+ wantLog := fmt.Sprintf("sending OTP message to %s %s", contactTypeStr, truncatedContactInfo)
+ assert.Contains(t, entries[0].Message, wantLog)
+ },
+ wantStatusCode: http.StatusInternalServerError,
+ wantBody: fmt.Sprintf(`{"error":"Failed to send OTP message, reason: sending OTP message: cannot send OTP message through %s to %s: failed calling message dispatcher"}`, utils.Humanize(string(contactType)), truncatedContactInfo),
+ },
+ {
+ name: fmt.Sprintf("%s/%s/🟡 (200-Ok) with false positive", contactType, verificationField),
+ receiverSendOTPRequest: receiverSendOTPRequest,
+ verificationField: verificationField,
+ contactType: contactType,
+ shouldCreateObjects: false,
+ prepareMocksFn: func(t *testing.T, mockReCAPTCHAValidator *validators.ReCAPTCHAValidatorMock, mockMessageDispatcher *message.MockMessageDispatcher) {
+ mockReCAPTCHAValidator.
+ On("IsTokenValid", mock.Anything, reCAPTCHAToken).
+ Return(true, nil).
+ Once()
+ },
+ assertLogsFn: func(t *testing.T, contactType data.ReceiverContactType, r data.Receiver, entries []logrus.Entry) {
+ contactTypeStr := utils.Humanize(string(contactType))
+ wantLog := fmt.Sprintf("Could not find ANY receiver verification for %s %s: %v", contactTypeStr, truncatedContactInfo, data.ErrRecordNotFound)
+ assert.Contains(t, entries[0].Message, wantLog)
+ },
+ wantStatusCode: http.StatusOK,
+ wantBody: fmt.Sprintf(`{"message":"if your %s is registered, you'll receive an OTP","verification_field":"DATE_OF_BIRTH"}`, utils.Humanize(string(contactType))),
+ },
+ {
+ name: fmt.Sprintf("%s/%s/🟢 (200-Ok) OTP sent!", contactType, verificationField),
+ receiverSendOTPRequest: receiverSendOTPRequest,
+ verificationField: verificationField,
+ contactType: contactType,
+ shouldCreateObjects: true,
+ prepareMocksFn: func(t *testing.T, mockReCAPTCHAValidator *validators.ReCAPTCHAValidatorMock, mockMessageDispatcher *message.MockMessageDispatcher) {
+ mockReCAPTCHAValidator.
+ On("IsTokenValid", mock.Anything, reCAPTCHAToken).
+ Return(true, nil).
+ Once()
+ mockMessageDispatcher.
+ On("SendMessage",
+ mock.Anything,
+ mock.AnythingOfType("message.Message"),
+ []message.MessageChannel{message.MessageChannelSMS, message.MessageChannelEmail}).
+ Return(messengerType, nil).
+ Once().
+ Run(func(args mock.Arguments) {
+ msg := args.Get(1).(message.Message)
+ assert.Contains(t, msg.Body, "is your MyCustomAid verification code.")
+ assert.Regexp(t, regexp.MustCompile(`^\d{6}\s.+$`), msg.Body)
+ })
+ },
+ wantStatusCode: http.StatusOK,
+ wantBody: fmt.Sprintf(`{"message":"if your %s is registered, you'll receive an OTP","verification_field":"%s"}`, utils.Humanize(string(contactType)), verificationField),
+ },
+ }...)
}
- reqBody, _ = json.Marshal(requestSendOTP)
-
- reCAPTCHAValidator.
- On("IsTokenValid", mock.Anything, "XyZ").
- Return(true, nil).
- Once()
- req, err := http.NewRequest(http.MethodPost, "/wallet-registration/otp", strings.NewReader(string(reqBody)))
- require.NoError(t, err)
-
- validClaims := &anchorplatform.SEP24JWTClaims{
- ClientDomainClaim: wallet1.SEP10ClientDomain,
- RegisteredClaims: jwt.RegisteredClaims{
- ID: "test-transaction-id",
- Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444",
- ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
- },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ defer data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool)
+ defer data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool)
+ defer data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool)
+ if tc.shouldCreateObjects {
+ receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{
+ PhoneNumber: tc.receiverSendOTPRequest.PhoneNumber,
+ Email: tc.receiverSendOTPRequest.Email,
+ })
+ data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
+ data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
+ ReceiverID: receiver.ID,
+ VerificationField: tc.verificationField,
+ })
+ }
+
+ mockReCAPTCHAValidator := validators.NewReCAPTCHAValidatorMock(t)
+ mockMessageDispatcher := message.NewMockMessageDispatcher(t)
+
+ tc.prepareMocksFn(t, mockReCAPTCHAValidator, mockMessageDispatcher)
+
+ r := chi.NewRouter()
+ r.Post("/wallet-registration/otp", ReceiverSendOTPHandler{
+ Models: models,
+ MessageDispatcher: mockMessageDispatcher,
+ ReCAPTCHAValidator: mockReCAPTCHAValidator,
+ }.ServeHTTP)
+
+ reqBody, err := json.Marshal(tc.receiverSendOTPRequest)
+ require.NoError(t, err)
+ req, err := http.NewRequestWithContext(ctxWithValidSEP24Claims, http.MethodPost, "/wallet-registration/otp", strings.NewReader(string(reqBody)))
+ require.NoError(t, err)
+ rr := httptest.NewRecorder()
+
+ getEntries := log.DefaultLogger.StartTest(logrus.DebugLevel)
+ r.ServeHTTP(rr, req)
+
+ resp := rr.Result()
+ defer resp.Body.Close()
+ respBody, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ assert.Equal(t, tc.wantStatusCode, resp.StatusCode)
+ assert.JSONEq(t, tc.wantBody, string(respBody))
+ entries := getEntries()
+ if tc.assertLogsFn != nil {
+ tc.assertLogsFn(t, tc.contactType, data.Receiver{}, entries)
+ }
+ })
+ }
+}
+
+func Test_newReceiverSendOTPResponseBody(t *testing.T) {
+ for _, otpType := range data.GetAllReceiverContactTypes() {
+ for _, verificationType := range data.GetAllVerificationTypes() {
+ t.Run(fmt.Sprintf("%s/%s", otpType, verificationType), func(t *testing.T) {
+ gotBody := newReceiverSendOTPResponseBody(otpType, verificationType)
+ wantBody := ReceiverSendOTPResponseBody{
+ Message: fmt.Sprintf("if your %s is registered, you'll receive an OTP", utils.Humanize(string(otpType))),
+ VerificationField: verificationType,
+ }
+ require.Equal(t, wantBody, gotBody)
+ })
}
- req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims))
-
- rr := httptest.NewRecorder()
- r.ServeHTTP(rr, req)
-
- resp := rr.Result()
- respBody, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
-
- wantsBody := `{
- "message":"if your phone number is registered, you'll receive an OTP",
- "verification_field":"DATE_OF_BIRTH"
- }`
- assert.Equal(t, http.StatusOK, resp.StatusCode)
- assert.JSONEq(t, wantsBody, string(respBody))
- })
-
- t.Run("returns 400 - BadRequest when recaptcha token is invalid", func(t *testing.T) {
- reCAPTCHAValidator.
- On("IsTokenValid", mock.Anything, "XyZ").
- Return(false, nil).
- Once()
-
- req, err := http.NewRequest(http.MethodPost, "/wallet-registration/otp", strings.NewReader(string(reqBody)))
- require.NoError(t, err)
-
- validClaims := &anchorplatform.SEP24JWTClaims{
- ClientDomainClaim: wallet1.SEP10ClientDomain,
- RegisteredClaims: jwt.RegisteredClaims{
- ID: "test-transaction-id",
- Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444",
- ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
- },
+ }
+}
+
+func Test_ReceiverSendOTPHandler_sendOTP(t *testing.T) {
+ dbt := dbtest.Open(t)
+ defer dbt.Close()
+ dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
+ require.NoError(t, err)
+ defer dbConnectionPool.Close()
+
+ models, err := data.NewModels(dbConnectionPool)
+ require.NoError(t, err)
+
+ ctx := context.Background()
+ organization, err := models.Organizations.Get(ctx)
+ require.NoError(t, err)
+ defaultOTPMessageTemplate := organization.OTPMessageTemplate
+
+ phoneNumber := "+380443973607"
+ email := "foobar@test.com"
+ otp := "246810"
+
+ testCases := []struct {
+ name string
+ overrideOrgOTPTemplate string
+ wantMessage string
+ shouldDispatcherFail bool
+ }{
+ {
+ name: "dispacher fails",
+ overrideOrgOTPTemplate: defaultOTPMessageTemplate,
+ wantMessage: fmt.Sprintf("246810 is your %s verification code. If you did not request this code, please ignore. Do not share your code with anyone.", organization.Name),
+ },
+ {
+ name: "🎉 successful with default message",
+ overrideOrgOTPTemplate: defaultOTPMessageTemplate,
+ wantMessage: fmt.Sprintf("246810 is your %s verification code. If you did not request this code, please ignore. Do not share your code with anyone.", organization.Name),
+ },
+ {
+ name: "🎉 successful with custom message and pre-existing OTP tag",
+ overrideOrgOTPTemplate: "Here's your code: {{.OTP}}.",
+ wantMessage: "Here's your code: 246810. If you did not request this code, please ignore. Do not share your code with anyone.",
+ },
+ {
+ name: "🎉 successful with custom message and NO pre-existing OTP tag",
+ overrideOrgOTPTemplate: "is your one-time password.",
+ wantMessage: "246810 is your one-time password. If you did not request this code, please ignore. Do not share your code with anyone.",
+ },
+ }
+
+ for _, contactType := range data.GetAllReceiverContactTypes() {
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("%s/%s", contactType, tc.name), func(t *testing.T) {
+ var expectedMsg message.Message
+ var contactInfo string
+ var messengerType message.MessengerType
+ switch contactType {
+ case data.ReceiverContactTypeSMS:
+ expectedMsg = message.Message{ToPhoneNumber: phoneNumber, Body: tc.wantMessage}
+ contactInfo = phoneNumber
+ messengerType = message.MessengerTypeTwilioSMS
+ case data.ReceiverContactTypeEmail:
+ expectedMsg = message.Message{ToEmail: email, Body: tc.wantMessage, Title: "Your One-Time Password: " + otp}
+ contactInfo = email
+ messengerType = message.MessengerTypeAWSEmail
+ }
+
+ mockMessageDispatcher := message.NewMockMessageDispatcher(t)
+ mockCall := mockMessageDispatcher.
+ On("SendMessage",
+ mock.Anything,
+ expectedMsg,
+ []message.MessageChannel{message.MessageChannelSMS, message.MessageChannelEmail})
+ if !tc.shouldDispatcherFail {
+ mockCall.Return(messengerType, nil).Once()
+ } else {
+ mockCall.Return(messengerType, errors.New("error sending message")).Once()
+ }
+
+ handler := ReceiverSendOTPHandler{
+ Models: models,
+ MessageDispatcher: mockMessageDispatcher,
+ }
+
+ err = models.Organizations.Update(ctx, &data.OrganizationUpdate{
+ OTPMessageTemplate: &tc.overrideOrgOTPTemplate,
+ })
+ require.NoError(t, err)
+
+ err := handler.sendOTP(ctx, contactType, contactInfo, otp)
+ require.NoError(t, err)
+ })
}
- req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims))
+ }
+}
- w := httptest.NewRecorder()
- r.ServeHTTP(w, req)
+func Test_ReceiverSendOTPHandler_handleOTPForReceiver(t *testing.T) {
+ dbt := dbtest.Open(t)
+ defer dbt.Close()
+ dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
+ require.NoError(t, err)
+ defer dbConnectionPool.Close()
- resp := w.Result()
- respBody, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
+ models, err := data.NewModels(dbConnectionPool)
+ require.NoError(t, err)
- wantsBody := `
- {
- "error": "reCAPTCHA token is invalid"
- }
- `
- assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
- assert.JSONEq(t, wantsBody, string(respBody))
- })
+ ctx := context.Background()
+ wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "testWallet", "https://correct.test", "correct.test", "wallet123://")
+ receiverWithoutWalletInsert := &data.Receiver{
+ PhoneNumber: "+141555550000",
+ Email: "without_wallet@test.com",
+ }
- mockMessenger.AssertExpectations(t)
- reCAPTCHAValidator.AssertExpectations(t)
+ testCases := []struct {
+ name string
+ contactInfo func(r data.Receiver, contactType data.ReceiverContactType) string
+ dateOfBirth string
+ sep24ClientDomain string
+ prepareMocksFn func(t *testing.T, mockMessageDispatcher *message.MockMessageDispatcher)
+ assertLogsFn func(t *testing.T, contactType data.ReceiverContactType, r data.Receiver, entries []logrus.Entry)
+ wantVerificationField data.VerificationType
+ wantHttpErr func(contactType data.ReceiverContactType, r data.Receiver) *httperror.HTTPError
+ }{
+ {
+ name: "🟡 false positive if GetLatestByContactInfo returns no results",
+ contactInfo: func(r data.Receiver, contactType data.ReceiverContactType) string {
+ return "not_found"
+ },
+ assertLogsFn: func(t *testing.T, contactType data.ReceiverContactType, r data.Receiver, entries []logrus.Entry) {
+ contactTypeStr := utils.Humanize(string(contactType))
+ truncatedContactInfo := utils.TruncateString("not_found", 3)
+ wantLog := fmt.Sprintf("Could not find ANY receiver verification for %s %s: %v", contactTypeStr, truncatedContactInfo, data.ErrRecordNotFound)
+ assert.Contains(t, entries[0].Message, wantLog)
+ },
+ wantVerificationField: data.VerificationTypeDateOfBirth,
+ },
+ {
+ name: "🟡 false positive if UpdateOTPByReceiverContactInfoAndWalletDomain doesn't find a {
,client_domain} match (client_domain)",
+ contactInfo: func(r data.Receiver, contactType data.ReceiverContactType) string {
+ return r.ContactByType(contactType)
+ },
+ sep24ClientDomain: "incorrect.test",
+ assertLogsFn: func(t *testing.T, contactType data.ReceiverContactType, r data.Receiver, entries []logrus.Entry) {
+ contactTypeStr := utils.Humanize(string(contactType))
+ truncatedContactInfo := utils.TruncateString(r.ContactByType(contactType), 3)
+ wantLog := fmt.Sprintf("Could not find a match between %s (%s) and client domain (%s)", contactTypeStr, truncatedContactInfo, "incorrect.test")
+ assert.Contains(t, entries[0].Message, wantLog)
+ },
+ wantVerificationField: data.VerificationTypeDateOfBirth,
+ },
+ {
+ name: "🟡 false positive if UpdateOTPByReceiverContactInfoAndWalletDomain doesn't find a {,client_domain} match ()",
+ contactInfo: func(_ data.Receiver, contactType data.ReceiverContactType) string {
+ return receiverWithoutWalletInsert.ContactByType(contactType)
+ },
+ sep24ClientDomain: "correct.test",
+ assertLogsFn: func(t *testing.T, contactType data.ReceiverContactType, _ data.Receiver, entries []logrus.Entry) {
+ contactTypeStr := utils.Humanize(string(contactType))
+ truncatedContactInfo := utils.TruncateString(receiverWithoutWalletInsert.ContactByType(contactType), 3)
+ wantLog := fmt.Sprintf("Could not find a match between %s (%s) and client domain (%s)", contactTypeStr, truncatedContactInfo, "correct.test")
+ assert.Contains(t, entries[0].Message, wantLog)
+ },
+ wantVerificationField: data.VerificationTypeDateOfBirth,
+ },
+ {
+ name: "🔴 error if sendOTP fails",
+ contactInfo: func(r data.Receiver, contactType data.ReceiverContactType) string {
+ return r.ContactByType(contactType)
+ },
+ sep24ClientDomain: "correct.test",
+ prepareMocksFn: func(t *testing.T, mockMessageDispatcher *message.MockMessageDispatcher) {
+ mockMessageDispatcher.
+ On("SendMessage",
+ mock.Anything,
+ mock.AnythingOfType("message.Message"),
+ []message.MessageChannel{message.MessageChannelSMS, message.MessageChannelEmail}).
+ Return(message.MessengerTypeTwilioSMS, errors.New("error sending message")).
+ Once()
+ },
+ wantVerificationField: data.VerificationTypeDateOfBirth,
+ wantHttpErr: func(contactType data.ReceiverContactType, r data.Receiver) *httperror.HTTPError {
+ contactTypeStr := utils.Humanize(string(contactType))
+ truncatedContactInfo := utils.TruncateString(r.ContactByType(contactType), 3)
+ err := fmt.Errorf("sending OTP message: %w", fmt.Errorf("cannot send OTP message through %s to %s: %w", contactTypeStr, truncatedContactInfo, errors.New("error sending message")))
+ return httperror.InternalError(ctx, "Failed to send OTP message, reason: "+err.Error(), err, nil)
+ },
+ },
+ {
+ name: "🟢 successful",
+ contactInfo: func(r data.Receiver, contactType data.ReceiverContactType) string {
+ return r.ContactByType(contactType)
+ },
+ sep24ClientDomain: "correct.test",
+ prepareMocksFn: func(t *testing.T, mockMessageDispatcher *message.MockMessageDispatcher) {
+ mockMessageDispatcher.
+ On("SendMessage",
+ mock.Anything,
+ mock.AnythingOfType("message.Message"),
+ []message.MessageChannel{message.MessageChannelSMS, message.MessageChannelEmail}).
+ Return(message.MessengerTypeTwilioSMS, nil).
+ Once()
+ },
+ wantVerificationField: data.VerificationTypePin,
+ wantHttpErr: nil,
+ },
+ }
+
+ for _, contactType := range data.GetAllReceiverContactTypes() {
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("%s/%s", contactType, tc.name), func(t *testing.T) {
+ receiverWithWalletInsert := &data.Receiver{}
+ switch contactType {
+ case data.ReceiverContactTypeSMS:
+ receiverWithWalletInsert.PhoneNumber = "+141555551111"
+ case data.ReceiverContactTypeEmail:
+ receiverWithWalletInsert.Email = "with_wallet@test.com"
+ }
+
+ defer data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool)
+ defer data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool)
+ defer data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool)
+
+ handler := ReceiverSendOTPHandler{Models: models}
+ if tc.prepareMocksFn != nil {
+ mockMessageDispatcher := message.NewMockMessageDispatcher(t)
+ tc.prepareMocksFn(t, mockMessageDispatcher)
+ handler.MessageDispatcher = mockMessageDispatcher
+ }
+
+ // Setup receiver with Verification but without wallet:
+ receiverWithoutWallet := data.CreateReceiverFixture(t, ctx, dbConnectionPool, receiverWithoutWalletInsert)
+ _ = data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
+ ReceiverID: receiverWithoutWallet.ID,
+ VerificationField: data.VerificationTypePin,
+ VerificationValue: "123456",
+ })
+
+ // Setup receiver with Verification AND wallet:
+ receiverWithWallet := data.CreateReceiverFixture(t, ctx, dbConnectionPool, receiverWithWalletInsert)
+ _ = data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
+ ReceiverID: receiverWithWallet.ID,
+ VerificationField: data.VerificationTypePin,
+ VerificationValue: "123456",
+ })
+ _ = data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiverWithWallet.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
+
+ getEntries := log.DefaultLogger.StartTest(logrus.DebugLevel)
+
+ contactInfo := tc.contactInfo(*receiverWithWallet, contactType)
+ verificationField, httpErr := handler.handleOTPForReceiver(ctx, contactType, contactInfo, tc.sep24ClientDomain)
+ if tc.wantHttpErr != nil {
+ wantHTTPErr := tc.wantHttpErr(contactType, *receiverWithWallet)
+ require.NotNil(t, httpErr)
+ assert.Equal(t, *wantHTTPErr, *httpErr)
+ assert.Equal(t, tc.wantVerificationField, verificationField)
+ } else {
+ require.Nil(t, httpErr)
+ assert.Equal(t, tc.wantVerificationField, verificationField)
+ }
+
+ entries := getEntries()
+ if tc.assertLogsFn != nil {
+ tc.assertLogsFn(t, contactType, *receiverWithWallet, entries)
+ }
+ })
+ }
+ }
}
diff --git a/internal/serve/httphandler/receiver_wallets_handler.go b/internal/serve/httphandler/receiver_wallets_handler.go
index 527ccf708..f2ed32768 100644
--- a/internal/serve/httphandler/receiver_wallets_handler.go
+++ b/internal/serve/httphandler/receiver_wallets_handler.go
@@ -18,7 +18,7 @@ import (
"github.com/stellar/stellar-disbursement-platform-backend/stellar-multitenant/pkg/tenant"
)
-type RetryInvitationSMSResponse struct {
+type RetryInvitationMessageResponse struct {
ID string `json:"id"`
ReceiverID string `json:"receiver_id"`
WalletID string `json:"wallet_id"`
@@ -39,13 +39,13 @@ func (h ReceiverWalletsHandler) RetryInvitation(rw http.ResponseWriter, req *htt
var msg *events.Message
receiverWallet, err := db.RunInTransactionWithResult(ctx, h.Models.DBConnectionPool, nil, func(dbTx db.DBTransaction) (*data.ReceiverWallet, error) {
- receiverWallet, err := h.Models.ReceiverWallet.RetryInvitationSMS(ctx, dbTx, receiverWalletID)
+ receiverWallet, err := h.Models.ReceiverWallet.RetryInvitationMessage(ctx, dbTx, receiverWalletID)
if err != nil {
- return nil, fmt.Errorf("retrying invitation SMS for receiver wallet ID %s: %w", receiverWalletID, err)
+ return nil, fmt.Errorf("retrying invitation message for receiver wallet ID %s: %w", receiverWalletID, err)
}
- eventData := []schemas.EventReceiverWalletSMSInvitationData{{ReceiverWalletID: receiverWalletID}}
- msg, err = events.NewMessage(ctx, events.ReceiverWalletNewInvitationTopic, receiverWalletID, events.RetryReceiverWalletSMSInvitationType, eventData)
+ eventData := []schemas.EventReceiverWalletInvitationData{{ReceiverWalletID: receiverWalletID}}
+ msg, err = events.NewMessage(ctx, events.ReceiverWalletNewInvitationTopic, receiverWalletID, events.RetryReceiverWalletInvitationType, eventData)
if err != nil {
return nil, fmt.Errorf("creating event producer message: %w", err)
}
@@ -77,7 +77,7 @@ func (h ReceiverWalletsHandler) RetryInvitation(rw http.ResponseWriter, req *htt
}
}
- response := RetryInvitationSMSResponse{
+ response := RetryInvitationMessageResponse{
ID: receiverWallet.ID,
ReceiverID: receiverWallet.Receiver.ID,
WalletID: receiverWallet.Wallet.ID,
diff --git a/internal/serve/httphandler/receiver_wallets_handler_test.go b/internal/serve/httphandler/receiver_wallets_handler_test.go
index 40592d814..15a556b8b 100644
--- a/internal/serve/httphandler/receiver_wallets_handler_test.go
+++ b/internal/serve/httphandler/receiver_wallets_handler_test.go
@@ -94,8 +94,8 @@ func Test_RetryInvitation(t *testing.T) {
Topic: events.ReceiverWalletNewInvitationTopic,
Key: rw.ID,
TenantID: tnt.ID,
- Type: events.RetryReceiverWalletSMSInvitationType,
- Data: []schemas.EventReceiverWalletSMSInvitationData{
+ Type: events.RetryReceiverWalletInvitationType,
+ Data: []schemas.EventReceiverWalletInvitationData{
{
ReceiverWalletID: rw.ID,
},
@@ -147,8 +147,8 @@ func Test_RetryInvitation(t *testing.T) {
Topic: events.ReceiverWalletNewInvitationTopic,
Key: rw.ID,
TenantID: tnt.ID,
- Type: events.RetryReceiverWalletSMSInvitationType,
- Data: []schemas.EventReceiverWalletSMSInvitationData{
+ Type: events.RetryReceiverWalletInvitationType,
+ Data: []schemas.EventReceiverWalletInvitationData{
{
ReceiverWalletID: rw.ID,
},
@@ -220,8 +220,8 @@ func Test_RetryInvitation(t *testing.T) {
Topic: events.ReceiverWalletNewInvitationTopic,
Key: rw.ID,
TenantID: tnt.ID,
- Type: events.RetryReceiverWalletSMSInvitationType,
- Data: []schemas.EventReceiverWalletSMSInvitationData{
+ Type: events.RetryReceiverWalletInvitationType,
+ Data: []schemas.EventReceiverWalletInvitationData{
{ReceiverWalletID: rw.ID},
},
}
diff --git a/internal/serve/httphandler/registration_contact_types.go b/internal/serve/httphandler/registration_contact_types.go
new file mode 100644
index 000000000..1e75ebb37
--- /dev/null
+++ b/internal/serve/httphandler/registration_contact_types.go
@@ -0,0 +1,15 @@
+package httphandler
+
+import (
+ "net/http"
+
+ "github.com/stellar/go/support/render/httpjson"
+
+ "github.com/stellar/stellar-disbursement-platform-backend/internal/data"
+)
+
+type RegistrationContactTypesHandler struct{}
+
+func (c RegistrationContactTypesHandler) Get(w http.ResponseWriter, r *http.Request) {
+ httpjson.Render(w, data.AllRegistrationContactTypes(), httpjson.JSON)
+}
diff --git a/internal/serve/httphandler/registration_contact_types_test.go b/internal/serve/httphandler/registration_contact_types_test.go
new file mode 100644
index 000000000..2800037b9
--- /dev/null
+++ b/internal/serve/httphandler/registration_contact_types_test.go
@@ -0,0 +1,32 @@
+package httphandler
+
+import (
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func Test_RegistrationContactTypesHandler_Get(t *testing.T) {
+ h := RegistrationContactTypesHandler{}
+
+ rr := httptest.NewRecorder()
+ req, err := http.NewRequest("GET", "/receiver-contact-types", nil)
+ require.NoError(t, err)
+ http.HandlerFunc(h.Get).ServeHTTP(rr, req)
+ resp := rr.Result()
+ respBody, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ expectedJSON := `[
+ "EMAIL",
+ "PHONE_NUMBER",
+ "EMAIL_AND_WALLET_ADDRESS",
+ "PHONE_NUMBER_AND_WALLET_ADDRESS"
+ ]`
+ assert.JSONEq(t, expectedJSON, string(respBody))
+}
diff --git a/internal/serve/httphandler/statistics_handler_test.go b/internal/serve/httphandler/statistics_handler_test.go
index dd7bc83fa..4fff57490 100644
--- a/internal/serve/httphandler/statistics_handler_test.go
+++ b/internal/serve/httphandler/statistics_handler_test.go
@@ -88,15 +88,13 @@ func TestStatisticsHandler(t *testing.T) {
require.NoError(t, err)
asset1 := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://")
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Name: "disbursement 1",
- Status: data.CompletedDisbursementStatus,
- Asset: asset1,
- Wallet: wallet,
- Country: country,
+ Name: "disbursement 1",
+ Status: data.CompletedDisbursementStatus,
+ Asset: asset1,
+ Wallet: wallet,
})
t.Run("get statistics for existing disbursement with no data", func(t *testing.T) {
diff --git a/internal/serve/httphandler/update_receiver_handler.go b/internal/serve/httphandler/update_receiver_handler.go
index d9eb52c58..d0dae4a7f 100644
--- a/internal/serve/httphandler/update_receiver_handler.go
+++ b/internal/serve/httphandler/update_receiver_handler.go
@@ -4,8 +4,11 @@ import (
"errors"
"fmt"
"net/http"
+ "slices"
+ "strings"
"github.com/go-chi/chi/v5"
+ "github.com/lib/pq"
"github.com/stellar/go/support/http/httpdecode"
"github.com/stellar/go/support/log"
"github.com/stellar/go/support/render/httpjson"
@@ -23,7 +26,7 @@ type UpdateReceiverHandler struct {
func createVerificationInsert(updateReceiverInfo *validators.UpdateReceiverRequest, receiverID string) []data.ReceiverVerificationInsert {
receiverVerifications := []data.ReceiverVerificationInsert{}
- appendNewVerificationValue := func(verificationField data.VerificationField, verificationValue string) {
+ appendNewVerificationValue := func(verificationField data.VerificationType, verificationValue string) {
if verificationValue != "" {
receiverVerifications = append(receiverVerifications, data.ReceiverVerificationInsert{
ReceiverID: receiverID,
@@ -33,15 +36,15 @@ func createVerificationInsert(updateReceiverInfo *validators.UpdateReceiverReque
}
}
- for _, verificationField := range data.GetAllVerificationFields() {
+ for _, verificationField := range data.GetAllVerificationTypes() {
switch verificationField {
- case data.VerificationFieldDateOfBirth:
+ case data.VerificationTypeDateOfBirth:
appendNewVerificationValue(verificationField, updateReceiverInfo.DateOfBirth)
- case data.VerificationFieldYearMonth:
+ case data.VerificationTypeYearMonth:
appendNewVerificationValue(verificationField, updateReceiverInfo.YearMonth)
- case data.VerificationFieldPin:
+ case data.VerificationTypePin:
appendNewVerificationValue(verificationField, updateReceiverInfo.Pin)
- case data.VerificationFieldNationalID:
+ case data.VerificationTypeNationalID:
appendNewVerificationValue(verificationField, updateReceiverInfo.NationalID)
}
}
@@ -85,7 +88,7 @@ func (h UpdateReceiverHandler) UpdateReceiver(rw http.ResponseWriter, req *http.
receiver, err := db.RunInTransactionWithResult(ctx, h.DBConnectionPool, nil, func(dbTx db.DBTransaction) (response *data.Receiver, innerErr error) {
for _, rv := range receiverVerifications {
innerErr = h.Models.ReceiverVerification.UpsertVerificationValue(
- req.Context(),
+ ctx,
dbTx,
rv.ReceiverID,
rv.VerificationField,
@@ -93,31 +96,61 @@ func (h UpdateReceiverHandler) UpdateReceiver(rw http.ResponseWriter, req *http.
)
if innerErr != nil {
- return nil, fmt.Errorf("error updating receiver verification %s: %w", rv.VerificationField, innerErr)
+ return nil, fmt.Errorf("updating receiver verification %s: %w", rv.VerificationField, innerErr)
}
}
- receiverUpdate := data.ReceiverUpdate{
- Email: reqBody.Email,
- ExternalId: reqBody.ExternalID,
+ var receiverUpdate data.ReceiverUpdate
+ if reqBody.Email != "" {
+ receiverUpdate.Email = &reqBody.Email
}
- if receiverUpdate.Email != "" || receiverUpdate.ExternalId != "" {
+ if reqBody.PhoneNumber != "" {
+ receiverUpdate.PhoneNumber = &reqBody.PhoneNumber
+ }
+ if reqBody.ExternalID != "" {
+ receiverUpdate.ExternalId = &reqBody.ExternalID
+ }
+
+ if !receiverUpdate.IsEmpty() {
if innerErr = h.Models.Receiver.Update(ctx, dbTx, receiverID, receiverUpdate); innerErr != nil {
- return nil, fmt.Errorf("error updating receiver with ID %s: %w", receiverID, innerErr)
+ return nil, fmt.Errorf("updating receiver with ID %s: %w", receiverID, innerErr)
}
}
receiver, innerErr := h.Models.Receiver.Get(ctx, dbTx, receiverID)
if innerErr != nil {
- return nil, fmt.Errorf("error querying receiver with ID %s: %w", receiverID, innerErr)
+ return nil, fmt.Errorf("querying receiver with ID %s: %w", receiverID, innerErr)
}
return receiver, nil
})
if err != nil {
+ if httpErr := parseHttpConflictErrorIfNeeded(err); httpErr != nil {
+ httpErr.Render(rw)
+ return
+ }
+
httperror.InternalError(ctx, "", err, nil).Render(rw)
return
}
httpjson.Render(rw, receiver, httpjson.JSON)
}
+
+func parseHttpConflictErrorIfNeeded(err error) *httperror.HTTPError {
+ var pqErr *pq.Error
+ if err == nil || !errors.As(err, &pqErr) || pqErr.Code != "23505" {
+ return nil
+ }
+
+ allowedConstraints := []string{"receiver_unique_email", "receiver_unique_phone_number"}
+ if !slices.Contains(allowedConstraints, pqErr.Constraint) {
+ return nil
+ }
+ fieldName := strings.Replace(pqErr.Constraint, "receiver_unique_", "", 1)
+ msg := fmt.Sprintf("The provided %s is already associated with another user.", fieldName)
+
+ return httperror.Conflict(msg, err, map[string]interface{}{
+ fieldName: fieldName + " must be unique",
+ })
+}
diff --git a/internal/serve/httphandler/update_receiver_handler_test.go b/internal/serve/httphandler/update_receiver_handler_test.go
index 5a267d1a6..ff33fced8 100644
--- a/internal/serve/httphandler/update_receiver_handler_test.go
+++ b/internal/serve/httphandler/update_receiver_handler_test.go
@@ -26,25 +26,25 @@ func Test_UpdateReceiverHandler_createVerificationInsert(t *testing.T) {
verificationDOB := data.ReceiverVerificationInsert{
ReceiverID: receiverID,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
VerificationValue: "1999-01-01",
}
verificationYearMonth := data.ReceiverVerificationInsert{
ReceiverID: receiverID,
- VerificationField: data.VerificationFieldYearMonth,
+ VerificationField: data.VerificationTypeYearMonth,
VerificationValue: "1999-01",
}
verificationPIN := data.ReceiverVerificationInsert{
ReceiverID: receiverID,
- VerificationField: data.VerificationFieldPin,
+ VerificationField: data.VerificationTypePin,
VerificationValue: "123",
}
verificationNationalID := data.ReceiverVerificationInsert{
ReceiverID: receiverID,
- VerificationField: data.VerificationFieldNationalID,
+ VerificationField: data.VerificationTypeNationalID,
VerificationValue: "12345CODE",
}
@@ -98,10 +98,9 @@ func Test_UpdateReceiverHandler_createVerificationInsert(t *testing.T) {
}
}
-func Test_UpdateReceiverHandler(t *testing.T) {
+func Test_UpdateReceiverHandler_400(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()
-
dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
require.NoError(t, err)
defer dbConnectionPool.Close()
@@ -115,26 +114,21 @@ func Test_UpdateReceiverHandler(t *testing.T) {
}
ctx := context.Background()
- receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{
- PhoneNumber: "+380445555555",
- Email: &[]string{"receiver@email.com"}[0],
- ExternalID: "externalID",
- })
+ receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, nil)
// setup
r := chi.NewRouter()
r.Patch("/receivers/{id}", handler.UpdateReceiver)
- t.Run("error invalid request body", func(t *testing.T) {
- testCases := []struct {
- name string
- request validators.UpdateReceiverRequest
- want string
- }{
- {
- name: "empty request body",
- request: validators.UpdateReceiverRequest{},
- want: `
+ testCases := []struct {
+ name string
+ request validators.UpdateReceiverRequest
+ expectedBody string
+ }{
+ {
+ name: "empty request body",
+ request: validators.UpdateReceiverRequest{},
+ expectedBody: `
{
"error": "request invalid",
"extras": {
@@ -142,11 +136,11 @@ func Test_UpdateReceiverHandler(t *testing.T) {
}
}
`,
- },
- {
- name: "invalid date of birth",
- request: validators.UpdateReceiverRequest{DateOfBirth: "invalid"},
- want: `
+ },
+ {
+ name: "invalid date of birth",
+ request: validators.UpdateReceiverRequest{DateOfBirth: "invalid"},
+ expectedBody: `
{
"error": "request invalid",
"extras": {
@@ -154,11 +148,11 @@ func Test_UpdateReceiverHandler(t *testing.T) {
}
}
`,
- },
- {
- name: "invalid year/month",
- request: validators.UpdateReceiverRequest{YearMonth: "invalid"},
- want: `
+ },
+ {
+ name: "invalid year/month",
+ request: validators.UpdateReceiverRequest{YearMonth: "invalid"},
+ expectedBody: `
{
"error": "request invalid",
"extras": {
@@ -166,11 +160,11 @@ func Test_UpdateReceiverHandler(t *testing.T) {
}
}
`,
- },
- {
- name: "invalid pin",
- request: validators.UpdateReceiverRequest{Pin: " "},
- want: `
+ },
+ {
+ name: "invalid pin",
+ request: validators.UpdateReceiverRequest{Pin: " "},
+ expectedBody: `
{
"error": "request invalid",
"extras": {
@@ -178,11 +172,11 @@ func Test_UpdateReceiverHandler(t *testing.T) {
}
}
`,
- },
- {
- name: "invalid national ID - empty",
- request: validators.UpdateReceiverRequest{NationalID: " "},
- want: `
+ },
+ {
+ name: "invalid national ID - empty",
+ request: validators.UpdateReceiverRequest{NationalID: " "},
+ expectedBody: `
{
"error": "request invalid",
"extras": {
@@ -190,11 +184,11 @@ func Test_UpdateReceiverHandler(t *testing.T) {
}
}
`,
- },
- {
- name: "invalid national ID - too long",
- request: validators.UpdateReceiverRequest{NationalID: fmt.Sprintf("%0*d", utils.VerificationFieldMaxIdLength+1, 0)},
- want: `
+ },
+ {
+ name: "invalid national ID - too long",
+ request: validators.UpdateReceiverRequest{NationalID: fmt.Sprintf("%0*d", utils.VerificationFieldMaxIdLength+1, 0)},
+ expectedBody: `
{
"error": "request invalid",
"extras": {
@@ -202,11 +196,11 @@ func Test_UpdateReceiverHandler(t *testing.T) {
}
}
`,
- },
- {
- name: "invalid email",
- request: validators.UpdateReceiverRequest{Email: "invalid"},
- want: `
+ },
+ {
+ name: "invalid email",
+ request: validators.UpdateReceiverRequest{Email: "invalid"},
+ expectedBody: `
{
"error": "request invalid",
"extras": {
@@ -214,451 +208,421 @@ func Test_UpdateReceiverHandler(t *testing.T) {
}
}
`,
- },
- {
- name: "invalid external ID",
- request: validators.UpdateReceiverRequest{ExternalID: " "},
- want: `
+ },
+ {
+ name: "invalid phone number",
+ request: validators.UpdateReceiverRequest{PhoneNumber: "invalid"},
+ expectedBody: `
{
"error": "request invalid",
"extras": {
- "external_id": "invalid external_id format"
+ "phone_number": "invalid phone number format"
}
}
`,
- },
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- route := fmt.Sprintf("/receivers/%s", receiver.ID)
- reqBody, err := json.Marshal(tc.request)
- require.NoError(t, err)
- req, err := http.NewRequest("PATCH", route, strings.NewReader(string(reqBody)))
- require.NoError(t, err)
+ },
+ {
+ name: "invalid external ID",
+ request: validators.UpdateReceiverRequest{ExternalID: " "},
+ expectedBody: `
+ {
+ "error": "request invalid",
+ "extras": {
+ "external_id": "external_id cannot be set to empty"
+ }
+ }
+ `,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ route := fmt.Sprintf("/receivers/%s", receiver.ID)
+ reqBody, err := json.Marshal(tc.request)
+ require.NoError(t, err)
+ req, err := http.NewRequest("PATCH", route, strings.NewReader(string(reqBody)))
+ require.NoError(t, err)
- rr := httptest.NewRecorder()
- r.ServeHTTP(rr, req)
+ rr := httptest.NewRecorder()
+ r.ServeHTTP(rr, req)
- resp := rr.Result()
- respBody, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
+ resp := rr.Result()
+ respBody, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
- assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
- assert.JSONEq(t, tc.want, string(respBody))
- })
- }
- })
+ assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
+ assert.JSONEq(t, tc.expectedBody, string(respBody))
+ })
+ }
+}
+
+func Test_UpdateReceiverHandler_404(t *testing.T) {
+ dbt := dbtest.Open(t)
+ defer dbt.Close()
+ dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
+ require.NoError(t, err)
+ defer dbConnectionPool.Close()
- t.Run("receiver not found", func(t *testing.T) {
- request := validators.UpdateReceiverRequest{DateOfBirth: "1999-01-01"}
+ models, err := data.NewModels(dbConnectionPool)
+ require.NoError(t, err)
- route := fmt.Sprintf("/receivers/%s", "invalid_receiver_id")
- reqBody, err := json.Marshal(request)
- require.NoError(t, err)
- req, err := http.NewRequest("PATCH", route, strings.NewReader(string(reqBody)))
- require.NoError(t, err)
+ handler := &UpdateReceiverHandler{
+ Models: models,
+ DBConnectionPool: dbConnectionPool,
+ }
- rr := httptest.NewRecorder()
- r.ServeHTTP(rr, req)
+ // setup
+ r := chi.NewRouter()
+ r.Patch("/receivers/{id}", handler.UpdateReceiver)
- resp := rr.Result()
- assert.Equal(t, http.StatusNotFound, resp.StatusCode)
- })
+ request := validators.UpdateReceiverRequest{DateOfBirth: "1999-01-01"}
- t.Run("update date of birth value", func(t *testing.T) {
- data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
- ReceiverID: receiver.ID,
- VerificationField: data.VerificationFieldDateOfBirth,
- VerificationValue: "2000-01-01",
- })
+ route := fmt.Sprintf("/receivers/%s", "invalid_receiver_id")
+ reqBody, err := json.Marshal(request)
+ require.NoError(t, err)
+ req, err := http.NewRequest("PATCH", route, strings.NewReader(string(reqBody)))
+ require.NoError(t, err)
- request := validators.UpdateReceiverRequest{DateOfBirth: "1999-01-01"}
-
- route := fmt.Sprintf("/receivers/%s", receiver.ID)
- reqBody, err := json.Marshal(request)
- require.NoError(t, err)
- req, err := http.NewRequest("PATCH", route, strings.NewReader(string(reqBody)))
- require.NoError(t, err)
-
- rr := httptest.NewRecorder()
- r.ServeHTTP(rr, req)
-
- resp := rr.Result()
- assert.Equal(t, http.StatusOK, resp.StatusCode)
-
- query := `
- SELECT
- hashed_value
- FROM
- receiver_verifications
- WHERE
- receiver_id = $1 AND
- verification_field = $2
- `
-
- newReceiverVerification := data.ReceiverVerification{}
- err = dbConnectionPool.GetContext(ctx, &newReceiverVerification, query, receiver.ID, data.VerificationFieldDateOfBirth)
- require.NoError(t, err)
-
- assert.True(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, "1999-01-01"))
- assert.False(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, "2000-01-01"))
-
- receiverDB, err := models.Receiver.Get(ctx, dbConnectionPool, receiver.ID)
- require.NoError(t, err)
- assert.Equal(t, "receiver@email.com", *receiverDB.Email)
- assert.Equal(t, "externalID", receiverDB.ExternalID)
- })
+ rr := httptest.NewRecorder()
+ r.ServeHTTP(rr, req)
- t.Run("update year/month value", func(t *testing.T) {
- data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
- ReceiverID: receiver.ID,
- VerificationField: data.VerificationFieldYearMonth,
- VerificationValue: "2000-01",
- })
+ resp := rr.Result()
+ assert.Equal(t, http.StatusNotFound, resp.StatusCode)
+}
- request := validators.UpdateReceiverRequest{YearMonth: "1999-01"}
-
- route := fmt.Sprintf("/receivers/%s", receiver.ID)
- reqBody, err := json.Marshal(request)
- require.NoError(t, err)
- req, err := http.NewRequest("PATCH", route, strings.NewReader(string(reqBody)))
- require.NoError(t, err)
-
- rr := httptest.NewRecorder()
- r.ServeHTTP(rr, req)
-
- resp := rr.Result()
- assert.Equal(t, http.StatusOK, resp.StatusCode)
-
- query := `
- SELECT
- hashed_value
- FROM
- receiver_verifications
- WHERE
- receiver_id = $1 AND
- verification_field = $2
- `
-
- newReceiverVerification := data.ReceiverVerification{}
- err = dbConnectionPool.GetContext(ctx, &newReceiverVerification, query, receiver.ID, data.VerificationFieldYearMonth)
- require.NoError(t, err)
-
- assert.True(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, "1999-01"))
- assert.False(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, "2000-01"))
-
- receiverDB, err := models.Receiver.Get(ctx, dbConnectionPool, receiver.ID)
- require.NoError(t, err)
- assert.Equal(t, "receiver@email.com", *receiverDB.Email)
- assert.Equal(t, "externalID", receiverDB.ExternalID)
- })
+func Test_UpdateReceiverHandler_409(t *testing.T) {
+ dbt := dbtest.Open(t)
+ defer dbt.Close()
+ dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
+ require.NoError(t, err)
+ defer dbConnectionPool.Close()
- t.Run("update pin value", func(t *testing.T) {
- data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
- ReceiverID: receiver.ID,
- VerificationField: data.VerificationFieldPin,
- VerificationValue: "8901",
- })
+ models, err := data.NewModels(dbConnectionPool)
+ require.NoError(t, err)
- request := validators.UpdateReceiverRequest{Pin: "1234"}
-
- route := fmt.Sprintf("/receivers/%s", receiver.ID)
- reqBody, err := json.Marshal(request)
- require.NoError(t, err)
- req, err := http.NewRequest("PATCH", route, strings.NewReader(string(reqBody)))
- require.NoError(t, err)
-
- rr := httptest.NewRecorder()
- r.ServeHTTP(rr, req)
-
- resp := rr.Result()
- assert.Equal(t, http.StatusOK, resp.StatusCode)
-
- query := `
- SELECT
- hashed_value
- FROM
- receiver_verifications
- WHERE
- receiver_id = $1 AND
- verification_field = $2
- `
-
- newReceiverVerification := data.ReceiverVerification{}
- err = dbConnectionPool.GetContext(ctx, &newReceiverVerification, query, receiver.ID, data.VerificationFieldPin)
- require.NoError(t, err)
-
- assert.True(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, "1234"))
- assert.False(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, "8901"))
-
- receiverDB, err := models.Receiver.Get(ctx, dbConnectionPool, receiver.ID)
- require.NoError(t, err)
- assert.Equal(t, "receiver@email.com", *receiverDB.Email)
- assert.Equal(t, "externalID", receiverDB.ExternalID)
- })
+ handler := &UpdateReceiverHandler{
+ Models: models,
+ DBConnectionPool: dbConnectionPool,
+ }
- t.Run("update national ID value", func(t *testing.T) {
- data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
- ReceiverID: receiver.ID,
- VerificationField: data.VerificationFieldNationalID,
- VerificationValue: "OLDID890",
- })
+ ctx := context.Background()
- request := validators.UpdateReceiverRequest{NationalID: "NEWID123"}
-
- route := fmt.Sprintf("/receivers/%s", receiver.ID)
- reqBody, err := json.Marshal(request)
- require.NoError(t, err)
- req, err := http.NewRequest("PATCH", route, strings.NewReader(string(reqBody)))
- require.NoError(t, err)
-
- rr := httptest.NewRecorder()
- r.ServeHTTP(rr, req)
-
- resp := rr.Result()
- assert.Equal(t, http.StatusOK, resp.StatusCode)
-
- query := `
- SELECT
- hashed_value
- FROM
- receiver_verifications
- WHERE
- receiver_id = $1 AND
- verification_field = $2
- `
-
- newReceiverVerification := data.ReceiverVerification{}
- err = dbConnectionPool.GetContext(ctx, &newReceiverVerification, query, receiver.ID, data.VerificationFieldNationalID)
- require.NoError(t, err)
-
- assert.True(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, "NEWID123"))
- assert.False(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, "OLDID890"))
-
- receiverDB, err := models.Receiver.Get(ctx, dbConnectionPool, receiver.ID)
- require.NoError(t, err)
- assert.Equal(t, "receiver@email.com", *receiverDB.Email)
- assert.Equal(t, "externalID", receiverDB.ExternalID)
+ // setup
+ r := chi.NewRouter()
+ r.Patch("/receivers/{id}", handler.UpdateReceiver)
+
+ receiverStatic := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{
+ PhoneNumber: "+14155556666",
})
+ receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, nil)
- t.Run("update multiples receiver verifications values", func(t *testing.T) {
- data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool)
+ testCases := []struct {
+ fieldName string
+ request validators.UpdateReceiverRequest
+ expectedBody string
+ }{
+ {
+ fieldName: "email conflict",
+ request: validators.UpdateReceiverRequest{
+ Email: receiverStatic.Email,
+ },
+ expectedBody: `{
+ "error": "The provided email is already associated with another user.",
+ "extras": {
+ "email": "email must be unique"
+ }
+ }`,
+ },
+ {
+ fieldName: "phone_number",
+ request: validators.UpdateReceiverRequest{
+ PhoneNumber: receiverStatic.PhoneNumber,
+ },
+ expectedBody: `{
+ "error": "The provided phone_number is already associated with another user.",
+ "extras": {
+ "phone_number": "phone_number must be unique"
+ }
+ }`,
+ },
+ }
- data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
- ReceiverID: receiver.ID,
- VerificationField: data.VerificationFieldDateOfBirth,
- VerificationValue: "2000-01-01",
- })
+ for _, tc := range testCases {
+ t.Run(tc.fieldName, func(t *testing.T) {
+ route := fmt.Sprintf("/receivers/%s", receiver.ID)
+ reqBody, err := json.Marshal(tc.request)
+ require.NoError(t, err)
+ req, err := http.NewRequest(http.MethodPatch, route, strings.NewReader(string(reqBody)))
+ require.NoError(t, err)
- data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
- ReceiverID: receiver.ID,
- VerificationField: data.VerificationFieldYearMonth,
- VerificationValue: "2000-01",
- })
+ rr := httptest.NewRecorder()
+ r.ServeHTTP(rr, req)
- data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
- ReceiverID: receiver.ID,
- VerificationField: data.VerificationFieldPin,
- VerificationValue: "8901",
- })
+ resp := rr.Result()
+ respBody, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
- data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
- ReceiverID: receiver.ID,
- VerificationField: data.VerificationFieldNationalID,
- VerificationValue: "OLDID890",
+ assert.Equal(t, http.StatusConflict, resp.StatusCode)
+ assert.JSONEq(t, tc.expectedBody, string(respBody))
})
+ }
+}
- request := validators.UpdateReceiverRequest{
- DateOfBirth: "1999-01-01",
- YearMonth: "1999-01",
- Pin: "1234",
- NationalID: "NEWID123",
- }
-
- route := fmt.Sprintf("/receivers/%s", receiver.ID)
- reqBody, err := json.Marshal(request)
- require.NoError(t, err)
- req, err := http.NewRequest("PATCH", route, strings.NewReader(string(reqBody)))
- require.NoError(t, err)
-
- rr := httptest.NewRecorder()
- r.ServeHTTP(rr, req)
-
- resp := rr.Result()
- assert.Equal(t, http.StatusOK, resp.StatusCode)
-
- query := `
- SELECT
- hashed_value
- FROM
- receiver_verifications
- WHERE
- receiver_id = $1 AND
- verification_field = $2
- `
-
- receiverVerifications := []struct {
- verificationField data.VerificationField
- newVerificationValue string
- oldVerificationValue string
- }{
- {
- verificationField: data.VerificationFieldDateOfBirth,
- newVerificationValue: "1999-01-01",
- oldVerificationValue: "2000-01-01",
- },
- {
- verificationField: data.VerificationFieldYearMonth,
- newVerificationValue: "1999-01",
- oldVerificationValue: "2000-01",
- },
- {
- verificationField: data.VerificationFieldPin,
- newVerificationValue: "1234",
- oldVerificationValue: "8901",
- },
- {
- verificationField: data.VerificationFieldNationalID,
- newVerificationValue: "NEWID123",
- oldVerificationValue: "OLDID890",
- },
- }
- for _, v := range receiverVerifications {
- newReceiverVerification := data.ReceiverVerification{}
- err = dbConnectionPool.GetContext(ctx, &newReceiverVerification, query, receiver.ID, v.verificationField)
- require.NoError(t, err)
+func Test_UpdateReceiverHandler_200ok_updateReceiverFields(t *testing.T) {
+ dbt := dbtest.Open(t)
+ defer dbt.Close()
+ dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
+ require.NoError(t, err)
+ defer dbConnectionPool.Close()
- assert.True(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, v.newVerificationValue))
- assert.False(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, v.oldVerificationValue))
+ models, err := data.NewModels(dbConnectionPool)
+ require.NoError(t, err)
- receiverDB, err := models.Receiver.Get(ctx, dbConnectionPool, receiver.ID)
- require.NoError(t, err)
- assert.Equal(t, "receiver@email.com", *receiverDB.Email)
- assert.Equal(t, "externalID", receiverDB.ExternalID)
- }
- })
+ handler := &UpdateReceiverHandler{
+ Models: models,
+ DBConnectionPool: dbConnectionPool,
+ }
- t.Run("updates and inserts receiver verifications values", func(t *testing.T) {
- data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool)
+ ctx := context.Background()
- request := validators.UpdateReceiverRequest{
- DateOfBirth: "1999-01-01",
- YearMonth: "1999-01",
- Pin: "1234",
- NationalID: "NEWID123",
- }
+ // setup
+ r := chi.NewRouter()
+ r.Patch("/receivers/{id}", handler.UpdateReceiver)
- route := fmt.Sprintf("/receivers/%s", receiver.ID)
- reqBody, err := json.Marshal(request)
- require.NoError(t, err)
- req, err := http.NewRequest(http.MethodPatch, route, strings.NewReader(string(reqBody)))
- require.NoError(t, err)
-
- rr := httptest.NewRecorder()
- r.ServeHTTP(rr, req)
-
- resp := rr.Result()
- assert.Equal(t, http.StatusOK, resp.StatusCode)
-
- query := `
- SELECT
- hashed_value
- FROM
- receiver_verifications
- WHERE
- receiver_id = $1 AND
- verification_field = $2
- `
-
- receiverVerifications := []struct {
- verificationField data.VerificationField
- newVerificationValue string
- oldVerificationValue string
- }{
- {
- verificationField: data.VerificationFieldDateOfBirth,
- newVerificationValue: "1999-01-01",
- oldVerificationValue: "2000-01-01",
+ testCases := []struct {
+ fieldName string
+ request validators.UpdateReceiverRequest
+ assertFn func(t *testing.T, receiver *data.Receiver)
+ }{
+ {
+ fieldName: "email",
+ request: validators.UpdateReceiverRequest{
+ Email: "update_receiver@email.com",
},
- {
- verificationField: data.VerificationFieldYearMonth,
- newVerificationValue: "1999-01",
- oldVerificationValue: "",
+ assertFn: func(t *testing.T, receiver *data.Receiver) {
+ assert.Equal(t, "update_receiver@email.com", receiver.Email)
},
- {
- verificationField: data.VerificationFieldPin,
- newVerificationValue: "1234",
- oldVerificationValue: "",
+ },
+ {
+ fieldName: "phone_number",
+ request: validators.UpdateReceiverRequest{
+ PhoneNumber: "+14155556666",
},
- {
- verificationField: data.VerificationFieldNationalID,
- newVerificationValue: "NEWID123",
- oldVerificationValue: "",
+ assertFn: func(t *testing.T, receiver *data.Receiver) {
+ assert.Equal(t, "+14155556666", receiver.PhoneNumber)
},
- }
- for _, v := range receiverVerifications {
- newReceiverVerification := data.ReceiverVerification{}
- err = dbConnectionPool.GetContext(ctx, &newReceiverVerification, query, receiver.ID, v.verificationField)
+ },
+ {
+ fieldName: "external_id",
+ request: validators.UpdateReceiverRequest{
+ ExternalID: "newExternalID",
+ },
+ assertFn: func(t *testing.T, receiver *data.Receiver) {
+ assert.Equal(t, "newExternalID", receiver.ExternalID)
+ },
+ },
+ {
+ fieldName: "ALL FIELDS",
+ request: validators.UpdateReceiverRequest{
+ Email: "update_receiver@email.com",
+ PhoneNumber: "+14155556666",
+ ExternalID: "newExternalID",
+ },
+ assertFn: func(t *testing.T, receiver *data.Receiver) {
+ assert.Equal(t, "update_receiver@email.com", receiver.Email)
+ assert.Equal(t, "+14155556666", receiver.PhoneNumber)
+ assert.Equal(t, "newExternalID", receiver.ExternalID)
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.fieldName, func(t *testing.T) {
+ defer data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool)
+ defer data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool)
+
+ receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, nil)
+ data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
+ ReceiverID: receiver.ID,
+ VerificationField: data.VerificationTypeDateOfBirth,
+ VerificationValue: "2000-01-01",
+ })
+
+ route := fmt.Sprintf("/receivers/%s", receiver.ID)
+ reqBody, err := json.Marshal(tc.request)
+ require.NoError(t, err)
+ req, err := http.NewRequest(http.MethodPatch, route, strings.NewReader(string(reqBody)))
require.NoError(t, err)
- t.Logf("newReceiverVerification: %+v", newReceiverVerification)
- assert.True(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, v.newVerificationValue))
+ rr := httptest.NewRecorder()
+ r.ServeHTTP(rr, req)
- if v.oldVerificationValue != "" {
- assert.False(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, v.oldVerificationValue))
- }
+ resp := rr.Result()
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
receiverDB, err := models.Receiver.Get(ctx, dbConnectionPool, receiver.ID)
require.NoError(t, err)
- assert.Equal(t, "receiver@email.com", *receiverDB.Email)
- assert.Equal(t, "externalID", receiverDB.ExternalID)
- }
- })
- t.Run("updates receiver's email", func(t *testing.T) {
- request := validators.UpdateReceiverRequest{
- Email: "update_receiver@email.com",
- }
+ tc.assertFn(t, receiverDB)
+ })
+ }
+}
- route := fmt.Sprintf("/receivers/%s", receiver.ID)
- reqBody, err := json.Marshal(request)
- require.NoError(t, err)
+// upsertAction is a helper type to define the action to be taken by the handler when upserting the receiver verification.
+type upsertAction string
- req, err := http.NewRequest(http.MethodPatch, route, strings.NewReader(string(reqBody)))
- require.NoError(t, err)
+const (
+ actionUpdate upsertAction = "UPDATE"
+ actionInsert upsertAction = "INSERT"
+)
- rr := httptest.NewRecorder()
- r.ServeHTTP(rr, req)
+// shouldPreInsert is a helper function to determine if the receiver verification should be inserted before the request is
+// made, so we test if the handler is updating the verification value. Otherwise, the receiver verification will be inserted
+// as a consequence of the request.
+func (ua upsertAction) shouldPreInsert() bool {
+ return ua == actionUpdate
+}
- resp := rr.Result()
- assert.Equal(t, http.StatusOK, resp.StatusCode)
+func Test_UpdateReceiverHandler_200ok_upsertVerificationFields(t *testing.T) {
+ dbt := dbtest.Open(t)
+ defer dbt.Close()
+ dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
+ require.NoError(t, err)
+ defer dbConnectionPool.Close()
- receiverDB, err := models.Receiver.Get(ctx, dbConnectionPool, receiver.ID)
- require.NoError(t, err)
- assert.Equal(t, "update_receiver@email.com", *receiverDB.Email)
- })
+ models, err := data.NewModels(dbConnectionPool)
+ require.NoError(t, err)
- t.Run("updates receiver's external ID", func(t *testing.T) {
- request := validators.UpdateReceiverRequest{
- ExternalID: "newExternalID",
+ handler := &UpdateReceiverHandler{
+ Models: models,
+ DBConnectionPool: dbConnectionPool,
+ }
+
+ ctx := context.Background()
+
+ // setup
+ r := chi.NewRouter()
+ r.Patch("/receivers/{id}", handler.UpdateReceiver)
+
+ assertVerificationFieldsContains := func(t *testing.T, rvList []data.ReceiverVerification, vt data.VerificationType, verifValue string) {
+ var rv data.ReceiverVerification
+ for _, _rv := range rvList {
+ if _rv.VerificationField == vt {
+ rv = _rv
+ break
+ }
}
+ require.NotEmptyf(t, rv, "receiver verification of type %s not found", vt)
+
+ assert.Equal(t, vt, rv.VerificationField)
+ assert.True(t, data.CompareVerificationValue(rv.HashedValue, verifValue), "hashed value does not match")
+ }
+
+ testCases := []struct {
+ fieldName string
+ request validators.UpdateReceiverRequest
+ assertFn func(t *testing.T, rvList []data.ReceiverVerification)
+ }{
+ {
+ fieldName: "date_of_birth",
+ request: validators.UpdateReceiverRequest{
+ DateOfBirth: "2000-01-01",
+ },
+ assertFn: func(t *testing.T, rvList []data.ReceiverVerification) {
+ assertVerificationFieldsContains(t, rvList, data.VerificationTypeDateOfBirth, "2000-01-01")
+ },
+ },
+ {
+ fieldName: "year_month",
+ request: validators.UpdateReceiverRequest{
+ YearMonth: "2000-01",
+ },
+ assertFn: func(t *testing.T, rvList []data.ReceiverVerification) {
+ assertVerificationFieldsContains(t, rvList, data.VerificationTypeYearMonth, "2000-01")
+ },
+ },
+ {
+ fieldName: "pin",
+ request: validators.UpdateReceiverRequest{
+ Pin: "123456",
+ },
+ assertFn: func(t *testing.T, rvList []data.ReceiverVerification) {
+ assertVerificationFieldsContains(t, rvList, data.VerificationTypePin, "123456")
+ },
+ },
+ {
+ fieldName: "national_id",
+ request: validators.UpdateReceiverRequest{
+ NationalID: "abcd1234",
+ },
+ assertFn: func(t *testing.T, rvList []data.ReceiverVerification) {
+ assertVerificationFieldsContains(t, rvList, data.VerificationTypeNationalID, "abcd1234")
+ },
+ },
+ {
+ fieldName: "ALL FIELDS",
+ request: validators.UpdateReceiverRequest{
+ DateOfBirth: "2000-01-01",
+ YearMonth: "2000-01",
+ Pin: "123456",
+ NationalID: "abcd1234",
+ },
+ assertFn: func(t *testing.T, rvList []data.ReceiverVerification) {
+ assertVerificationFieldsContains(t, rvList, data.VerificationTypeDateOfBirth, "2000-01-01")
+ assertVerificationFieldsContains(t, rvList, data.VerificationTypeYearMonth, "2000-01")
+ assertVerificationFieldsContains(t, rvList, data.VerificationTypePin, "123456")
+ assertVerificationFieldsContains(t, rvList, data.VerificationTypeNationalID, "abcd1234")
+ },
+ },
+ }
- route := fmt.Sprintf("/receivers/%s", receiver.ID)
- reqBody, err := json.Marshal(request)
- require.NoError(t, err)
+ for _, action := range []upsertAction{actionUpdate, actionInsert} {
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("%s/%s", action, tc.fieldName), func(t *testing.T) {
+ defer data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool)
+ defer data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool)
+
+ receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, nil)
+
+ if action.shouldPreInsert() {
+ data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
+ ReceiverID: receiver.ID,
+ VerificationField: data.VerificationTypeDateOfBirth,
+ VerificationValue: "1999-01-01",
+ })
+ data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
+ ReceiverID: receiver.ID,
+ VerificationField: data.VerificationTypeYearMonth,
+ VerificationValue: "1999-01",
+ })
+ data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
+ ReceiverID: receiver.ID,
+ VerificationField: data.VerificationTypePin,
+ VerificationValue: "000000",
+ })
+ data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
+ ReceiverID: receiver.ID,
+ VerificationField: data.VerificationTypeNationalID,
+ VerificationValue: "aaaa0000",
+ })
+ }
- req, err := http.NewRequest(http.MethodPatch, route, strings.NewReader(string(reqBody)))
- require.NoError(t, err)
+ route := fmt.Sprintf("/receivers/%s", receiver.ID)
+ reqBody, err := json.Marshal(tc.request)
+ require.NoError(t, err)
+ req, err := http.NewRequest(http.MethodPatch, route, strings.NewReader(string(reqBody)))
+ require.NoError(t, err)
- rr := httptest.NewRecorder()
- r.ServeHTTP(rr, req)
+ rr := httptest.NewRecorder()
+ r.ServeHTTP(rr, req)
- resp := rr.Result()
- assert.Equal(t, http.StatusOK, resp.StatusCode)
+ resp := rr.Result()
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
- receiverDB, err := models.Receiver.Get(ctx, dbConnectionPool, receiver.ID)
- require.NoError(t, err)
+ rvSlice, err := models.ReceiverVerification.GetAllByReceiverId(ctx, dbConnectionPool, receiver.ID)
+ require.NoError(t, err)
- assert.Equal(t, "newExternalID", receiverDB.ExternalID)
- })
+ tc.assertFn(t, rvSlice)
+ })
+ }
+ }
}
diff --git a/internal/serve/httphandler/user_handler_test.go b/internal/serve/httphandler/user_handler_test.go
index 77892ed50..b49500784 100644
--- a/internal/serve/httphandler/user_handler_test.go
+++ b/internal/serve/httphandler/user_handler_test.go
@@ -686,7 +686,7 @@ func Test_UserHandler_CreateUser(t *testing.T) {
forgotPasswordLink, err := urllib.JoinPath(uiBaseURL, "forgot-password")
require.NoError(t, err)
- content, err := htmltemplate.ExecuteHTMLTemplateForInvitationMessage(htmltemplate.InvitationMessageTemplate{
+ content, err := htmltemplate.ExecuteHTMLTemplateForStaffInvitationEmailMessage(htmltemplate.StaffInvitationEmailMessageTemplate{
FirstName: u.FirstName,
Role: u.Roles[0],
ForgotPasswordLink: forgotPasswordLink,
@@ -697,7 +697,7 @@ func Test_UserHandler_CreateUser(t *testing.T) {
msg := message.Message{
ToEmail: u.Email,
Title: "Welcome to Stellar Disbursement Platform",
- Message: content,
+ Body: content,
}
messengerClientMock.
On("SendMessage", msg).
@@ -853,7 +853,7 @@ func Test_UserHandler_CreateUser(t *testing.T) {
forgotPasswordLink, err := urllib.JoinPath(uiBaseURL, "forgot-password")
require.NoError(t, err)
- content, err := htmltemplate.ExecuteHTMLTemplateForInvitationMessage(htmltemplate.InvitationMessageTemplate{
+ content, err := htmltemplate.ExecuteHTMLTemplateForStaffInvitationEmailMessage(htmltemplate.StaffInvitationEmailMessageTemplate{
FirstName: u.FirstName,
Role: u.Roles[0],
ForgotPasswordLink: forgotPasswordLink,
@@ -864,7 +864,7 @@ func Test_UserHandler_CreateUser(t *testing.T) {
msg := message.Message{
ToEmail: u.Email,
Title: "Welcome to Stellar Disbursement Platform",
- Message: content,
+ Body: content,
}
messengerClientMock.
On("SendMessage", msg).
diff --git a/internal/serve/httphandler/verifiy_receiver_registration_handler.go b/internal/serve/httphandler/verify_receiver_registration_handler.go
similarity index 83%
rename from internal/serve/httphandler/verifiy_receiver_registration_handler.go
rename to internal/serve/httphandler/verify_receiver_registration_handler.go
index 3f47552c3..9d7261d7f 100644
--- a/internal/serve/httphandler/verifiy_receiver_registration_handler.go
+++ b/internal/serve/httphandler/verify_receiver_registration_handler.go
@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"net/http"
+ "strings"
"time"
"github.com/stellar/go/support/log"
@@ -17,6 +18,7 @@ import (
"github.com/stellar/stellar-disbursement-platform-backend/internal/data"
"github.com/stellar/stellar-disbursement-platform-backend/internal/events"
"github.com/stellar/stellar-disbursement-platform-backend/internal/events/schemas"
+ "github.com/stellar/stellar-disbursement-platform-backend/internal/message"
"github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror"
"github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators"
"github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine/signing"
@@ -114,39 +116,51 @@ func (v VerifyReceiverRegistrationHandler) processReceiverVerificationPII(
receiverRegistrationRequest data.ReceiverRegistrationRequest,
) error {
now := time.Now()
- truncatedPhoneNumber := utils.TruncateString(receiver.PhoneNumber, 3)
// STEP 1: find the receiverVerification entry that matches the pair [receiverID, verificationType]
- receiverVerifications, err := v.Models.ReceiverVerification.GetByReceiverIDsAndVerificationField(ctx, dbTx, []string{receiver.ID}, receiverRegistrationRequest.VerificationType)
+ receiverVerifications, err := v.Models.ReceiverVerification.GetByReceiverIDsAndVerificationField(ctx, dbTx, []string{receiver.ID}, receiverRegistrationRequest.VerificationField)
if err != nil {
- return fmt.Errorf("error retrieving receiver verification for verification type %s: %w", receiverRegistrationRequest.VerificationType, err)
+ return fmt.Errorf("retrieving receiver verification for verification type %s: %w", receiverRegistrationRequest.VerificationField, err)
}
if len(receiverVerifications) == 0 {
- err = fmt.Errorf("%s not found for receiver with phone number %s", receiverRegistrationRequest.VerificationType, truncatedPhoneNumber)
+ err = fmt.Errorf("verification of type %s not found for receiver id %s", receiverRegistrationRequest.VerificationField, receiver.ID)
return &ErrorInformationNotFound{cause: err}
}
if len(receiverVerifications) > 1 {
- log.Ctx(ctx).Warnf("receiver with id %s has more than one verification saved in the database for type %s", receiver.ID, receiverRegistrationRequest.VerificationType)
+ log.Ctx(ctx).Warnf("receiver with id %s has more than one verification saved in the database for type %s", receiver.ID, receiverRegistrationRequest.VerificationField)
}
receiverVerification := receiverVerifications[0]
// STEP 2: check if the number of attempts to confirm the verification value has already exceeded the max value
if v.Models.ReceiverVerification.ExceededAttempts(receiverVerification.Attempts) {
// TODO: the application currently can't recover from a max attempts exceeded error.
- err = fmt.Errorf("the number of attempts to confirm the verification value exceededs the max attempts")
+ err = fmt.Errorf("the number of attempts to confirm the verification value exceeded the max attempts")
return &ErrorVerificationAttemptsExceeded{cause: err}
}
// STEP 3: check if the payload verification value matches the one saved in the database
+ rvu := data.ReceiverVerificationUpdate{
+ ReceiverID: receiverVerification.ReceiverID,
+ VerificationField: receiverVerification.VerificationField,
+ }
+
+ if strings.TrimSpace(receiverRegistrationRequest.PhoneNumber) != "" {
+ rvu.VerificationChannel = message.MessageChannelSMS
+ } else if strings.TrimSpace(receiverRegistrationRequest.Email) != "" {
+ rvu.VerificationChannel = message.MessageChannelEmail
+ } else {
+ err = fmt.Errorf("no valid verification channel found resolved for receiver")
+ return &ErrorInformationNotFound{cause: err}
+ }
+
if !data.CompareVerificationValue(receiverVerification.HashedValue, receiverRegistrationRequest.VerificationValue) {
- baseErrMsg := fmt.Sprintf("%s value does not match for user with phone number %s", receiverRegistrationRequest.VerificationType, truncatedPhoneNumber)
+ baseErrMsg := fmt.Sprintf("%s value does not match for receiver with id %s", receiverRegistrationRequest.VerificationField, receiver.ID)
// update the receiver verification with the confirmation that the value was checked
- receiverVerification.Attempts = receiverVerification.Attempts + 1
- receiverVerification.FailedAt = &now
- receiverVerification.ConfirmedAt = nil
+ rvu.Attempts = utils.IntPtr(receiverVerification.Attempts + 1)
+ rvu.FailedAt = &now
- // this update is done using the DBConnectionPool and not dbTx because we don't want to roolback these changes after returning the error
- updateErr := v.Models.ReceiverVerification.UpdateReceiverVerification(ctx, *receiverVerification, v.Models.DBConnectionPool)
+ // this update is done using the DBConnectionPool and not dbTx because we don't want to rollback these changes after returning the error
+ updateErr := v.Models.ReceiverVerification.UpdateReceiverVerification(ctx, rvu, v.Models.DBConnectionPool)
if updateErr != nil {
err = fmt.Errorf("%s: %w", baseErrMsg, updateErr)
} else {
@@ -158,8 +172,9 @@ func (v VerifyReceiverRegistrationHandler) processReceiverVerificationPII(
// STEP 4: update the receiver verification row with the confirmation that the value was successfully validated
if receiverVerification.ConfirmedAt == nil {
- receiverVerification.ConfirmedAt = &now
- err = v.Models.ReceiverVerification.UpdateReceiverVerification(ctx, *receiverVerification, dbTx)
+ rvu.ConfirmedAt = &now
+
+ err = v.Models.ReceiverVerification.UpdateReceiverVerification(ctx, rvu, dbTx)
if err != nil {
return fmt.Errorf("updating successfully verified user: %w", err)
}
@@ -174,6 +189,7 @@ func (v VerifyReceiverRegistrationHandler) processReceiverWalletOTP(
dbTx db.DBTransaction,
sep24Claims anchorplatform.SEP24JWTClaims,
receiver data.Receiver, otp string,
+ contactInfo string,
) (receiverWallet data.ReceiverWallet, wasAlreadyRegistered bool, err error) {
// STEP 1: find the receiver wallet for the given [receiverID, clientDomain]
rw, err := v.Models.ReceiverWallet.GetByReceiverIDAndWalletDomain(ctx, receiver.ID, sep24Claims.ClientDomain(), dbTx)
@@ -205,6 +221,7 @@ func (v VerifyReceiverRegistrationHandler) processReceiverWalletOTP(
// STEP 5: update receiver wallet status to "REGISTERED"
now := time.Now()
rw.OTPConfirmedAt = &now
+ rw.OTPConfirmedWith = contactInfo
rw.Status = data.RegisteredReceiversWalletStatus
rw.StellarAddress = sep24Claims.SEP10StellarAccount()
rw.StellarMemo = sep24Claims.SEP10StellarMemo()
@@ -212,7 +229,14 @@ func (v VerifyReceiverRegistrationHandler) processReceiverWalletOTP(
if sep24Claims.SEP10StellarMemo() != "" {
rw.StellarMemoType = "id"
}
- err = v.Models.ReceiverWallet.UpdateReceiverWallet(ctx, *rw, dbTx)
+ err = v.Models.ReceiverWallet.Update(ctx, rw.ID, data.ReceiverWalletUpdate{
+ Status: rw.Status,
+ StellarAddress: rw.StellarAddress,
+ StellarMemo: rw.StellarMemo,
+ StellarMemoType: rw.StellarMemoType,
+ OTPConfirmedAt: now,
+ OTPConfirmedWith: rw.OTPConfirmedWith,
+ }, dbTx)
if err != nil {
err = fmt.Errorf("completing receiver wallet registration: %w", err)
return receiverWallet, false, err
@@ -225,8 +249,9 @@ func (v VerifyReceiverRegistrationHandler) processReceiverWalletOTP(
// the receiver wallet with the anchor platform transaction ID.
func (v VerifyReceiverRegistrationHandler) processAnchorPlatformID(ctx context.Context, dbTx db.DBTransaction, sep24Claims anchorplatform.SEP24JWTClaims, receiverWallet data.ReceiverWallet) error {
// STEP 1: update receiver wallet with the anchor platform transaction ID.
- receiverWallet.AnchorPlatformTransactionID = sep24Claims.TransactionID()
- err := v.Models.ReceiverWallet.UpdateReceiverWallet(ctx, receiverWallet, dbTx)
+ err := v.Models.ReceiverWallet.Update(ctx, receiverWallet.ID, data.ReceiverWalletUpdate{
+ AnchorPlatformTransactionID: sep24Claims.TransactionID(),
+ }, dbTx)
if err != nil {
return fmt.Errorf("updating receiver wallet with anchor platform transaction ID: %w", err)
}
@@ -261,19 +286,29 @@ func (v VerifyReceiverRegistrationHandler) VerifyReceiverRegistration(w http.Res
return
}
- truncatedPhoneNumber := utils.TruncateString(receiverRegistrationRequest.PhoneNumber, 3)
+ var contactInfo string
+ if receiverRegistrationRequest.PhoneNumber != "" {
+ contactInfo = receiverRegistrationRequest.PhoneNumber
+ } else if receiverRegistrationRequest.Email != "" {
+ contactInfo = receiverRegistrationRequest.Email
+ } else {
+ httperror.InternalError(ctx, "Unexpected contact info", nil, nil).Render(w)
+ return
+ }
+
+ truncatedContactInfo := utils.TruncateString(contactInfo, 3)
opts := db.TransactionOptions{
DBConnectionPool: v.Models.DBConnectionPool,
AtomicFunctionWithPostCommit: func(dbTx db.DBTransaction) (postCommitFn db.PostCommitFunction, err error) {
// STEP 2: find the receivers with the given phone number
- receivers, err := v.Models.Receiver.GetByPhoneNumbers(ctx, dbTx, []string{receiverRegistrationRequest.PhoneNumber})
+ receivers, err := v.Models.Receiver.GetByContacts(ctx, dbTx, contactInfo)
if err != nil {
- err = fmt.Errorf("error retrieving receiver with phone number %s: %w", truncatedPhoneNumber, err)
+ err = fmt.Errorf("retrieving receiver with contact info %s: %w", truncatedContactInfo, err)
return nil, err
}
if len(receivers) == 0 {
- err = fmt.Errorf("receiver with phone number %s not found in our server", truncatedPhoneNumber)
+ err = fmt.Errorf("receiver with contact info %s not found in our server", truncatedContactInfo)
return nil, &ErrorInformationNotFound{cause: err}
}
@@ -281,13 +316,13 @@ func (v VerifyReceiverRegistrationHandler) VerifyReceiverRegistration(w http.Res
receiver := receivers[0]
err = v.processReceiverVerificationPII(ctx, dbTx, *receiver, receiverRegistrationRequest)
if err != nil {
- return nil, fmt.Errorf("processing receiver verification entry for receiver with phone number %s: %w", truncatedPhoneNumber, err)
+ return nil, fmt.Errorf("processing receiver verification entry for receiver with contact info %s: %w", truncatedContactInfo, err)
}
// STEP 4: process OTP
- receiverWallet, wasAlreadyRegistered, err := v.processReceiverWalletOTP(ctx, dbTx, *sep24Claims, *receiver, receiverRegistrationRequest.OTP)
+ receiverWallet, wasAlreadyRegistered, err := v.processReceiverWalletOTP(ctx, dbTx, *sep24Claims, *receiver, receiverRegistrationRequest.OTP, contactInfo)
if err != nil {
- return nil, fmt.Errorf("processing OTP for receiver with phone number %s: %w", truncatedPhoneNumber, err)
+ return nil, fmt.Errorf("processing OTP for receiver with contact info %s: %w", truncatedContactInfo, err)
}
// STEP 5: build event message to trigger a transaction in the TSS
diff --git a/internal/serve/httphandler/verifiy_receiver_registration_handler_test.go b/internal/serve/httphandler/verify_receiver_registration_handler_test.go
similarity index 91%
rename from internal/serve/httphandler/verifiy_receiver_registration_handler_test.go
rename to internal/serve/httphandler/verify_receiver_registration_handler_test.go
index 8b2600a18..b15ea78c4 100644
--- a/internal/serve/httphandler/verifiy_receiver_registration_handler_test.go
+++ b/internal/serve/httphandler/verify_receiver_registration_handler_test.go
@@ -27,9 +27,11 @@ import (
"github.com/stellar/stellar-disbursement-platform-backend/internal/data"
"github.com/stellar/stellar-disbursement-platform-backend/internal/events"
"github.com/stellar/stellar-disbursement-platform-backend/internal/events/schemas"
+ "github.com/stellar/stellar-disbursement-platform-backend/internal/message"
"github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror"
"github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators"
sigMocks "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine/signing/mocks"
+ "github.com/stellar/stellar-disbursement-platform-backend/internal/utils"
"github.com/stellar/stellar-disbursement-platform-backend/pkg/schema"
"github.com/stellar/stellar-disbursement-platform-backend/stellar-multitenant/pkg/tenant"
)
@@ -100,7 +102,7 @@ func Test_VerifyReceiverRegistrationHandler_validate(t *testing.T) {
"phone_number": "+380445555555",
"otp": "",
"verification": "1990-01-01",
- "verification_type": "date_of_birth",
+ "verification_field": "date_of_birth",
"reCAPTCHA_token": "token"
}`,
isRecaptchaValidFnResponse: []interface{}{true, nil},
@@ -113,7 +115,7 @@ func Test_VerifyReceiverRegistrationHandler_validate(t *testing.T) {
"phone_number": "+380445555555",
"otp": "123456",
"verification": "1990-01-01",
- "verification_type": "date_of_birth",
+ "verification_field": "date_of_birth",
"reCAPTCHA_token": "token"
}`,
isRecaptchaValidFnResponse: []interface{}{true, nil},
@@ -122,7 +124,7 @@ func Test_VerifyReceiverRegistrationHandler_validate(t *testing.T) {
PhoneNumber: "+380445555555",
OTP: "123456",
VerificationValue: "1990-01-01",
- VerificationType: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
ReCAPTCHAToken: "token",
},
},
@@ -203,18 +205,23 @@ func Test_VerifyReceiverRegistrationHandler_processReceiverVerificationPII(t *te
receiverWithExceededAttempts := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{PhoneNumber: "+380446666666"})
receiverVerificationExceededAttempts := data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
ReceiverID: receiverWithExceededAttempts.ID,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
VerificationValue: "1990-01-01",
})
receiverVerificationExceededAttempts.Attempts = data.MaxAttemptsAllowed
- err = models.ReceiverVerification.UpdateReceiverVerification(ctx, *receiverVerificationExceededAttempts, dbConnectionPool)
+ err = models.ReceiverVerification.UpdateReceiverVerification(ctx, data.ReceiverVerificationUpdate{
+ ReceiverID: receiverWithExceededAttempts.ID,
+ VerificationField: data.VerificationTypeDateOfBirth,
+ Attempts: utils.IntPtr(data.MaxAttemptsAllowed),
+ VerificationChannel: message.MessageChannelSMS,
+ }, dbConnectionPool)
require.NoError(t, err)
// receiver with receiver_verification row:
receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{PhoneNumber: "+380445555555"})
_ = data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
ReceiverID: receiver.ID,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
VerificationValue: "1990-01-01",
})
@@ -233,58 +240,58 @@ func Test_VerifyReceiverRegistrationHandler_processReceiverVerificationPII(t *te
receiver: *receiverMissingReceiverVerification,
registrationRequest: data.ReceiverRegistrationRequest{
PhoneNumber: receiverMissingReceiverVerification.PhoneNumber,
- VerificationType: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
VerificationValue: "1990-01-01",
},
- wantErrContains: "DATE_OF_BIRTH not found for receiver with phone number +38...333",
+ wantErrContains: "DATE_OF_BIRTH not found for receiver id " + receiverMissingReceiverVerification.ID,
},
{
name: "returns an error if the receiver does not have any receiverVerification row with the given verification type (YEAR_MONTH)",
receiver: *receiver,
registrationRequest: data.ReceiverRegistrationRequest{
PhoneNumber: receiver.PhoneNumber,
- VerificationType: data.VerificationFieldYearMonth,
+ VerificationField: data.VerificationTypeYearMonth,
VerificationValue: "1999-12",
},
- wantErrContains: "YEAR_MONTH not found for receiver with phone number +38...555",
+ wantErrContains: "YEAR_MONTH not found for receiver id " + receiver.ID,
},
{
name: "returns an error if the receiver does not have any receiverVerification row with the given verification type (NATIONAL_ID_NUMBER)",
receiver: *receiver,
registrationRequest: data.ReceiverRegistrationRequest{
PhoneNumber: receiver.PhoneNumber,
- VerificationType: data.VerificationFieldNationalID,
+ VerificationField: data.VerificationTypeNationalID,
VerificationValue: "123456",
},
- wantErrContains: "NATIONAL_ID_NUMBER not found for receiver with phone number +38...555",
+ wantErrContains: "NATIONAL_ID_NUMBER not found for receiver id " + receiver.ID,
},
{
name: "returns an error if the receiver has exceeded their max attempts to confirm the verification value",
receiver: *receiverWithExceededAttempts,
registrationRequest: data.ReceiverRegistrationRequest{
PhoneNumber: receiverWithExceededAttempts.PhoneNumber,
- VerificationType: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
VerificationValue: "1990-01-01",
},
- wantErrContains: "the number of attempts to confirm the verification value exceededs the max attempts",
+ wantErrContains: "the number of attempts to confirm the verification value exceeded the max attempts",
},
{
name: "returns an error if the varification value provided in the payload is different from the DB one",
receiver: *receiver,
registrationRequest: data.ReceiverRegistrationRequest{
PhoneNumber: receiver.PhoneNumber,
- VerificationType: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
VerificationValue: "1990-11-11", // <--- different from the DB one (1990-01-01)
},
shouldAssertAttemptsCount: true,
- wantErrContains: "DATE_OF_BIRTH value does not match for user with phone number +38...555",
+ wantErrContains: "DATE_OF_BIRTH value does not match for receiver with id " + receiver.ID,
},
{
name: "🎉 successfully process the verification value and updates it accordingly in the DB",
receiver: *receiver,
registrationRequest: data.ReceiverRegistrationRequest{
PhoneNumber: receiver.PhoneNumber,
- VerificationType: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
VerificationValue: "1990-01-01",
},
shouldAssertAttemptsCount: true,
@@ -303,7 +310,7 @@ func Test_VerifyReceiverRegistrationHandler_processReceiverVerificationPII(t *te
var receiverVerifications []*data.ReceiverVerification
var receiverVerificationInitial *data.ReceiverVerification
if tc.shouldAssertAttemptsCount {
- receiverVerifications, err = models.ReceiverVerification.GetByReceiverIDsAndVerificationField(ctx, dbTx, []string{tc.receiver.ID}, tc.registrationRequest.VerificationType)
+ receiverVerifications, err = models.ReceiverVerification.GetByReceiverIDsAndVerificationField(ctx, dbTx, []string{tc.receiver.ID}, tc.registrationRequest.VerificationField)
require.NoError(t, err)
require.Len(t, receiverVerifications, 1)
receiverVerificationInitial = receiverVerifications[0]
@@ -312,7 +319,7 @@ func Test_VerifyReceiverRegistrationHandler_processReceiverVerificationPII(t *te
err = handler.processReceiverVerificationPII(ctx, dbTx, tc.receiver, tc.registrationRequest)
if tc.wantErrContains == "" {
- receiverVerifications, err = models.ReceiverVerification.GetByReceiverIDsAndVerificationField(ctx, dbTx, []string{tc.receiver.ID}, tc.registrationRequest.VerificationType)
+ receiverVerifications, err = models.ReceiverVerification.GetByReceiverIDsAndVerificationField(ctx, dbTx, []string{tc.receiver.ID}, tc.registrationRequest.VerificationField)
require.NoError(t, err)
require.Len(t, receiverVerifications, 1)
receiverVerification := receiverVerifications[0]
@@ -321,7 +328,7 @@ func Test_VerifyReceiverRegistrationHandler_processReceiverVerificationPII(t *te
} else {
require.ErrorContains(t, err, tc.wantErrContains)
if tc.shouldAssertAttemptsCount {
- receiverVerifications, err = models.ReceiverVerification.GetByReceiverIDsAndVerificationField(ctx, dbTx, []string{tc.receiver.ID}, tc.registrationRequest.VerificationType)
+ receiverVerifications, err = models.ReceiverVerification.GetByReceiverIDsAndVerificationField(ctx, dbTx, []string{tc.receiver.ID}, tc.registrationRequest.VerificationField)
require.NoError(t, err)
require.Len(t, receiverVerifications, 1)
receiverVerification := receiverVerifications[0]
@@ -420,6 +427,7 @@ func Test_VerifyReceiverRegistrationHandler_processReceiverWalletOTP(t *testing.
if !tc.shouldOTPMatch {
otp = wrongOTP
}
+ receiverEmail := "test@stellar.org"
// receiver & receiver wallet
receiver := data.CreateReceiverFixture(t, ctx, dbTx, &data.Receiver{PhoneNumber: "+380445555555"})
@@ -428,23 +436,25 @@ func Test_VerifyReceiverRegistrationHandler_processReceiverWalletOTP(t *testing.
receiverWallet = data.CreateReceiverWalletFixture(t, ctx, dbTx, receiver.ID, wallet.ID, tc.currentReceiverWalletStatus)
var stellarAddress string
var otpConfirmedAt *time.Time
+ var otpConfirmedWith string
if tc.wantWasAlreadyRegistered {
stellarAddress = "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444"
now := time.Now()
otpConfirmedAt = &now
+ otpConfirmedWith = receiverEmail
}
const q = `
UPDATE receiver_wallets
- SET otp = $1, otp_created_at = NOW(), stellar_address = $2, otp_confirmed_at = $3
- WHERE id = $4
+ SET otp = $1, otp_created_at = NOW(), stellar_address = $2, otp_confirmed_at = $3, otp_confirmed_with = $4
+ WHERE id = $5
`
- _, err = dbTx.ExecContext(ctx, q, correctOTP, sql.NullString{String: stellarAddress, Valid: stellarAddress != ""}, otpConfirmedAt, receiverWallet.ID)
+ _, err = dbTx.ExecContext(ctx, q, correctOTP, sql.NullString{String: stellarAddress, Valid: stellarAddress != ""}, otpConfirmedAt, otpConfirmedWith, receiverWallet.ID)
require.NoError(t, err)
}
// assertions
- rwUpdated, wasAlreadyRegistered, err := handler.processReceiverWalletOTP(ctx, dbTx, *tc.sep24Claims, *receiver, otp)
+ rwUpdated, wasAlreadyRegistered, err := handler.processReceiverWalletOTP(ctx, dbTx, *tc.sep24Claims, *receiver, otp, receiverEmail)
if tc.wantErrContains == nil {
require.NoError(t, err)
assert.Equal(t, tc.wantWasAlreadyRegistered, wasAlreadyRegistered)
@@ -459,6 +469,7 @@ func Test_VerifyReceiverRegistrationHandler_processReceiverWalletOTP(t *testing.
assert.Equal(t, rwUpdated.StellarAddress, rw.StellarAddress)
assert.NotNil(t, rw.OTPConfirmedAt)
assert.NotNil(t, rwUpdated.OTPConfirmedAt)
+ assert.Equal(t, rwUpdated.OTPConfirmedWith, receiverEmail)
assert.WithinDuration(t, *rwUpdated.OTPConfirmedAt, *rw.OTPConfirmedAt, time.Millisecond)
} else {
@@ -484,7 +495,7 @@ func Test_VerifyReceiverRegistrationHandler_processAnchorPlatformID(t *testing.T
handler := &VerifyReceiverRegistrationHandler{Models: models}
// creeate fixtures
- const phoneNumber = "+380445555555"
+ phoneNumber := "+380445555555"
defer data.DeleteAllFixtures(t, ctx, dbConnectionPool)
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "testWallet", "https://home.page", "home.page", "wallet123://")
receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{PhoneNumber: phoneNumber})
@@ -574,7 +585,6 @@ func Test_VerifyReceiverRegistrationHandler_buildPaymentsReadyToPayEventMessage(
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "testWallet", "https://home.page", "home.page", "wallet123://")
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "UKR", "Ukraine")
receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{})
rw := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
@@ -587,10 +597,9 @@ func Test_VerifyReceiverRegistrationHandler_buildPaymentsReadyToPayEventMessage(
}
pausedDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Wallet: wallet,
- Asset: asset,
- Country: country,
- Status: data.PausedDisbursementStatus,
+ Wallet: wallet,
+ Asset: asset,
+ Status: data.PausedDisbursementStatus,
})
_ = data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -627,10 +636,9 @@ func Test_VerifyReceiverRegistrationHandler_buildPaymentsReadyToPayEventMessage(
}
disbursement := data.CreateDisbursementFixture(t, ctxWithoutTenant, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Wallet: wallet,
- Asset: asset,
- Country: country,
- Status: data.StartedDisbursementStatus,
+ Wallet: wallet,
+ Asset: asset,
+ Status: data.StartedDisbursementStatus,
})
_ = data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -661,10 +669,9 @@ func Test_VerifyReceiverRegistrationHandler_buildPaymentsReadyToPayEventMessage(
}
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Wallet: wallet,
- Asset: asset,
- Country: country,
- Status: data.StartedDisbursementStatus,
+ Wallet: wallet,
+ Asset: asset,
+ Status: data.StartedDisbursementStatus,
})
payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -710,10 +717,9 @@ func Test_VerifyReceiverRegistrationHandler_buildPaymentsReadyToPayEventMessage(
}
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Wallet: wallet,
- Asset: asset,
- Country: country,
- Status: data.StartedDisbursementStatus,
+ Wallet: wallet,
+ Asset: asset,
+ Status: data.StartedDisbursementStatus,
})
payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -767,16 +773,28 @@ func Test_VerifyReceiverRegistrationHandler_VerifyReceiverRegistration(t *testin
},
}
- const phoneNumber = "+380445555555"
- receiverRegistrationRequest := data.ReceiverRegistrationRequest{
+ phoneNumber := "+380445555555"
+ receiverRegistrationRequestWithPhone := data.ReceiverRegistrationRequest{
PhoneNumber: phoneNumber,
OTP: "123456",
VerificationValue: "1990-01-01",
- VerificationType: "date_of_birth",
+ VerificationField: "date_of_birth",
ReCAPTCHAToken: "token",
}
- reqBody, err := json.Marshal(receiverRegistrationRequest)
+ reqBody, err := json.Marshal(receiverRegistrationRequestWithPhone)
require.NoError(t, err)
+
+ email := "test@stellar.org"
+ receiverRegistrationRequestWithEmail := data.ReceiverRegistrationRequest{
+ Email: email,
+ OTP: "123456",
+ VerificationValue: "1990-01-01",
+ VerificationField: "date_of_birth",
+ ReCAPTCHAToken: "token",
+ }
+ reqBodyEmail, err := json.Marshal(receiverRegistrationRequestWithEmail)
+ require.NoError(t, err)
+
r := chi.NewRouter()
t.Run("returns an error when validate() fails - testing case where a SEP24 claims are missing from the context", func(t *testing.T) {
@@ -837,7 +855,7 @@ func Test_VerifyReceiverRegistrationHandler_VerifyReceiverRegistration(t *testin
assert.JSONEq(t, wantBody, string(respBody))
// validate logs
- require.Contains(t, buf.String(), "receiver with phone number +38...555 not found in our server")
+ require.Contains(t, buf.String(), "receiver with contact info +38...555 not found in our server")
})
t.Run("returns an error when processReceiverVerificationPII() fails - testing case where no receiverVerification is found", func(t *testing.T) {
@@ -854,7 +872,7 @@ func Test_VerifyReceiverRegistrationHandler_VerifyReceiverRegistration(t *testin
// update database with the entries needed
defer data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool)
- _ = data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{PhoneNumber: phoneNumber})
+ receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{PhoneNumber: phoneNumber})
// set the logger to a buffer so we can check the error message
buf := new(strings.Builder)
@@ -877,7 +895,8 @@ func Test_VerifyReceiverRegistrationHandler_VerifyReceiverRegistration(t *testin
assert.JSONEq(t, wantBody, string(respBody))
// validate logs
- require.Contains(t, buf.String(), "processing receiver verification entry for receiver with phone number +38...555: DATE_OF_BIRTH not found for receiver with phone number +38...555")
+ expectedErr := `processing receiver verification entry for receiver with contact info +38...555: verification of type %s not found for receiver id %s`
+ require.Contains(t, buf.String(), fmt.Sprintf(expectedErr, data.VerificationTypeDateOfBirth, receiver.ID))
})
t.Run("returns an error when processReceiverVerificationPII() fails - testing case where maximum number of verification attempts exceeded", func(t *testing.T) {
@@ -898,11 +917,16 @@ func Test_VerifyReceiverRegistrationHandler_VerifyReceiverRegistration(t *testin
receiverWithExceededAttempts := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{PhoneNumber: phoneNumber})
receiverVerificationExceededAttempts := data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
ReceiverID: receiverWithExceededAttempts.ID,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
VerificationValue: "1990-01-01",
})
receiverVerificationExceededAttempts.Attempts = data.MaxAttemptsAllowed
- err = models.ReceiverVerification.UpdateReceiverVerification(ctx, *receiverVerificationExceededAttempts, dbConnectionPool)
+ err = models.ReceiverVerification.UpdateReceiverVerification(ctx, data.ReceiverVerificationUpdate{
+ ReceiverID: receiverWithExceededAttempts.ID,
+ VerificationField: data.VerificationTypeDateOfBirth,
+ Attempts: utils.IntPtr(data.MaxAttemptsAllowed),
+ VerificationChannel: message.MessageChannelSMS,
+ }, dbConnectionPool)
require.NoError(t, err)
// set the logger to a buffer so we can check the error message
@@ -922,7 +946,7 @@ func Test_VerifyReceiverRegistrationHandler_VerifyReceiverRegistration(t *testin
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
- expectedError := "the number of attempts to confirm the verification value exceededs the max attempts"
+ expectedError := "the number of attempts to confirm the verification value exceeded the max attempts"
wantBody := fmt.Sprintf(`{"error": "%s"}`, expectedError)
assert.JSONEq(t, wantBody, string(respBody))
@@ -948,7 +972,7 @@ func Test_VerifyReceiverRegistrationHandler_VerifyReceiverRegistration(t *testin
receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{PhoneNumber: phoneNumber})
_ = data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
ReceiverID: receiver.ID,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
VerificationValue: "1990-01-01",
})
@@ -973,7 +997,7 @@ func Test_VerifyReceiverRegistrationHandler_VerifyReceiverRegistration(t *testin
assert.JSONEq(t, wantBody, string(respBody))
// validate logs
- wantErrContains := fmt.Sprintf("processing OTP for receiver with phone number +38...555: receiver wallet not found for receiverID=%s and clientDomain=home.page", receiver.ID)
+ wantErrContains := fmt.Sprintf("processing OTP for receiver with contact info +38...555: receiver wallet not found for receiverID=%s and clientDomain=home.page", receiver.ID)
require.Contains(t, buf.String(), wantErrContains)
})
@@ -1008,11 +1032,11 @@ func Test_VerifyReceiverRegistrationHandler_VerifyReceiverRegistration(t *testin
receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{PhoneNumber: phoneNumber})
_ = data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
ReceiverID: receiver.ID,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
VerificationValue: "1990-01-01",
})
_ = data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.ReadyReceiversWalletStatus)
- _, err := models.ReceiverWallet.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, "+380445555555", wallet.SEP10ClientDomain, "123456")
+ _, err := models.ReceiverWallet.UpdateOTPByReceiverContactInfoAndWalletDomain(ctx, "+380445555555", wallet.SEP10ClientDomain, "123456")
require.NoError(t, err)
// set the logger to a buffer so we can check the error message
@@ -1088,11 +1112,11 @@ func Test_VerifyReceiverRegistrationHandler_VerifyReceiverRegistration(t *testin
receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{PhoneNumber: phoneNumber})
_ = data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
ReceiverID: receiver.ID,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
VerificationValue: "1990-01-01",
})
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.ReadyReceiversWalletStatus)
- _, err := models.ReceiverWallet.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, "+380445555555", wallet.SEP10ClientDomain, "123456")
+ _, err := models.ReceiverWallet.UpdateOTPByReceiverContactInfoAndWalletDomain(ctx, "+380445555555", wallet.SEP10ClientDomain, "123456")
require.NoError(t, err)
// setup router and execute request
@@ -1190,19 +1214,19 @@ func Test_VerifyReceiverRegistrationHandler_VerifyReceiverRegistration(t *testin
defer data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool)
defer data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool)
defer data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool)
- receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{PhoneNumber: phoneNumber})
+ receiver := data.InsertReceiverFixture(t, ctx, dbConnectionPool, &data.ReceiverInsert{Email: &email})
_ = data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
ReceiverID: receiver.ID,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
VerificationValue: "1990-01-01",
})
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.ReadyReceiversWalletStatus)
- _, err := models.ReceiverWallet.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, "+380445555555", wallet.SEP10ClientDomain, "123456")
+ _, err := models.ReceiverWallet.UpdateOTPByReceiverContactInfoAndWalletDomain(ctx, email, wallet.SEP10ClientDomain, "123456")
require.NoError(t, err)
// setup router and execute request
r.Post("/wallet-registration/verification", handler.VerifyReceiverRegistration)
- req, err := http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBody)))
+ req, err := http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBodyEmail)))
require.NoError(t, err)
req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, &sep24Claims))
rr := httptest.NewRecorder()
@@ -1246,12 +1270,12 @@ func Test_VerifyReceiverRegistrationHandler_VerifyReceiverRegistration(t *testin
// registering Second Wallet
receiverWallet2 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet2.ID, data.ReadyReceiversWalletStatus)
- _, err = models.ReceiverWallet.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, "+380445555555", wallet2.SEP10ClientDomain, "123456")
+ _, err = models.ReceiverWallet.UpdateOTPByReceiverContactInfoAndWalletDomain(ctx, email, wallet2.SEP10ClientDomain, "123456")
require.NoError(t, err)
sep24Claims.ClientDomainClaim = wallet2.SEP10ClientDomain
- req, err = http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBody)))
+ req, err = http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBodyEmail)))
require.NoError(t, err)
req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, &sep24Claims))
rr = httptest.NewRecorder()
@@ -1301,7 +1325,6 @@ func Test_VerifyReceiverRegistrationHandler_VerifyReceiverRegistration(t *testin
// update database with the entries needed
defer data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool)
- defer data.DeleteAllCountryFixtures(t, ctx, dbConnectionPool)
defer data.DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool)
defer data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool)
defer data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool)
@@ -1311,21 +1334,19 @@ func Test_VerifyReceiverRegistrationHandler_VerifyReceiverRegistration(t *testin
receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{PhoneNumber: phoneNumber})
_ = data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{
ReceiverID: receiver.ID,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
VerificationValue: "1990-01-01",
})
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.ReadyReceiversWalletStatus)
- _, err := models.ReceiverWallet.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, "+380445555555", wallet.SEP10ClientDomain, "123456")
+ _, err := models.ReceiverWallet.UpdateOTPByReceiverContactInfoAndWalletDomain(ctx, "+380445555555", wallet.SEP10ClientDomain, "123456")
require.NoError(t, err)
// Creating a payment ready to pay
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "UKR", "Ukraine")
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Wallet: wallet,
- Asset: asset,
- Country: country,
- Status: data.StartedDisbursementStatus,
+ Wallet: wallet,
+ Asset: asset,
+ Status: data.StartedDisbursementStatus,
})
payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
Amount: "100",
diff --git a/internal/serve/httphandler/wallets_handler.go b/internal/serve/httphandler/wallets_handler.go
index 28d925c32..7721f524b 100644
--- a/internal/serve/httphandler/wallets_handler.go
+++ b/internal/serve/httphandler/wallets_handler.go
@@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"net/http"
- "strconv"
"github.com/go-chi/chi/v5"
"github.com/stellar/go/support/http/httpdecode"
@@ -13,35 +12,52 @@ import (
"github.com/stellar/stellar-disbursement-platform-backend/internal/data"
"github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror"
"github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators"
+ "github.com/stellar/stellar-disbursement-platform-backend/internal/utils"
)
type WalletsHandler struct {
- Models *data.Models
+ Models *data.Models
+ NetworkType utils.NetworkType
}
// GetWallets returns a list of wallets
func (h WalletsHandler) GetWallets(w http.ResponseWriter, r *http.Request) {
- context := r.Context()
+ ctx := r.Context()
- enabledParam := r.URL.Query().Get("enabled")
- var enabledFilter *bool
- if enabledParam != "" {
- enabledValue, err := strconv.ParseBool(enabledParam)
- if err != nil {
- httperror.BadRequest("Invalid enabled parameter value", nil, nil).Render(w)
- return
- }
- enabledFilter = &enabledValue
+ filters, err := h.parseFilters(r)
+ if err != nil {
+ extras := map[string]interface{}{"validation_error": err.Error()}
+ httperror.BadRequest("Error parsing request filters", nil, extras).Render(w)
+ return
}
- wallets, err := h.Models.Wallets.FindWallets(context, enabledFilter)
+ wallets, err := h.Models.Wallets.FindWallets(ctx, filters...)
if err != nil {
- httperror.InternalError(context, "Cannot retrieve list of wallets", err, nil).Render(w)
+ httperror.InternalError(ctx, "Cannot retrieve list of wallets", err, nil).Render(w)
return
}
httpjson.Render(w, wallets, httpjson.JSON)
}
+func (h WalletsHandler) parseFilters(r *http.Request) ([]data.Filter, error) {
+ filters := []data.Filter{}
+ filterParams := map[string]data.FilterKey{
+ "enabled": data.FilterEnabledWallets,
+ "user_managed": data.FilterUserManaged,
+ }
+
+ for param, filterType := range filterParams {
+ paramValue, err := utils.ParseBoolQueryParam(r, param)
+ if err != nil {
+ return nil, fmt.Errorf("invalid '%s' parameter value", param)
+ }
+ if paramValue != nil {
+ filters = append(filters, data.NewFilter(filterType, *paramValue))
+ }
+ }
+ return filters, nil
+}
+
func (h WalletsHandler) PostWallets(rw http.ResponseWriter, req *http.Request) {
ctx := req.Context()
@@ -52,7 +68,7 @@ func (h WalletsHandler) PostWallets(rw http.ResponseWriter, req *http.Request) {
}
validator := validators.NewWalletValidator()
- reqBody = validator.ValidateCreateWalletRequest(ctx, reqBody)
+ reqBody = validator.ValidateCreateWalletRequest(ctx, reqBody, h.NetworkType.IsPubnet())
if validator.HasErrors() {
httperror.BadRequest("invalid request body", nil, validator.Errors).Render(rw)
return
@@ -67,6 +83,9 @@ func (h WalletsHandler) PostWallets(rw http.ResponseWriter, req *http.Request) {
})
if err != nil {
switch {
+ case errors.Is(err, data.ErrInvalidAssetID):
+ httperror.BadRequest(data.ErrInvalidAssetID.Error(), err, nil).Render(rw)
+ return
case errors.Is(err, data.ErrWalletNameAlreadyExists):
httperror.Conflict(data.ErrWalletNameAlreadyExists.Error(), err, nil).Render(rw)
return
@@ -76,9 +95,6 @@ func (h WalletsHandler) PostWallets(rw http.ResponseWriter, req *http.Request) {
case errors.Is(err, data.ErrWalletDeepLinkSchemaAlreadyExists):
httperror.Conflict(data.ErrWalletDeepLinkSchemaAlreadyExists.Error(), err, nil).Render(rw)
return
- case errors.Is(err, data.ErrInvalidAssetID):
- httperror.Conflict(data.ErrInvalidAssetID.Error(), err, nil).Render(rw)
- return
}
httperror.InternalError(ctx, "", err, nil).Render(rw)
diff --git a/internal/serve/httphandler/wallets_handler_test.go b/internal/serve/httphandler/wallets_handler_test.go
index 3ca54842c..ccbcd3851 100644
--- a/internal/serve/httphandler/wallets_handler_test.go
+++ b/internal/serve/httphandler/wallets_handler_test.go
@@ -17,6 +17,7 @@ import (
"github.com/stellar/stellar-disbursement-platform-backend/db"
"github.com/stellar/stellar-disbursement-platform-backend/db/dbtest"
"github.com/stellar/stellar-disbursement-platform-backend/internal/data"
+ "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror"
)
func Test_WalletsHandlerGetWallets(t *testing.T) {
@@ -111,6 +112,47 @@ func Test_WalletsHandlerGetWallets(t *testing.T) {
require.JSONEq(t, string(expectedJSON), string(respBody))
})
+ t.Run("successfully returns a list of user managed wallets", func(t *testing.T) {
+ wallets := data.ClearAndCreateWalletFixtures(t, ctx, dbConnectionPool)
+
+ // make first wallet user managed
+ data.MakeWalletUserManaged(t, ctx, dbConnectionPool, wallets[0].ID)
+
+ rr := httptest.NewRecorder()
+ req, _ := http.NewRequest("GET", "/wallets?user_managed=true", nil)
+ http.HandlerFunc(handler.GetWallets).ServeHTTP(rr, req)
+
+ resp := rr.Result()
+
+ respBody, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ respWallets := []data.Wallet{}
+ err = json.Unmarshal(respBody, &respWallets)
+ require.NoError(t, err)
+ assert.Equal(t, 1, len(respWallets))
+ assert.Equal(t, wallets[0].ID, respWallets[0].ID)
+ assert.Equal(t, wallets[0].Name, respWallets[0].Name)
+ })
+
+ t.Run("bad request when user_managed parameter isn't a bool", func(t *testing.T) {
+ rr := httptest.NewRecorder()
+ req, _ := http.NewRequest("GET", "/wallets?user_managed=xxx", nil)
+ http.HandlerFunc(handler.GetWallets).ServeHTTP(rr, req)
+
+ resp := rr.Result()
+
+ respBody, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ require.Equal(t, http.StatusBadRequest, resp.StatusCode)
+ var httpErr httperror.HTTPError
+ err = json.Unmarshal(respBody, &httpErr)
+ require.NoError(t, err)
+ assert.Equal(t, "invalid 'user_managed' parameter value", httpErr.Extras["validation_error"])
+ })
+
t.Run("bad request when enabled parameter isn't a bool", func(t *testing.T) {
rr := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/wallets?enabled=xxx", nil)
@@ -122,14 +164,16 @@ func Test_WalletsHandlerGetWallets(t *testing.T) {
require.NoError(t, err)
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
- require.JSONEq(t, `{"error": "Invalid enabled parameter value"}`, string(respBody))
+ var httpErr httperror.HTTPError
+ err = json.Unmarshal(respBody, &httpErr)
+ require.NoError(t, err)
+ assert.Equal(t, "invalid 'enabled' parameter value", httpErr.Extras["validation_error"])
})
}
func Test_WalletsHandlerPostWallets(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()
-
dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
require.NoError(t, err)
defer dbConnectionPool.Close()
@@ -138,44 +182,30 @@ func Test_WalletsHandlerPostWallets(t *testing.T) {
require.NoError(t, err)
ctx := context.Background()
+ handler := &WalletsHandler{Models: models}
- handler := &WalletsHandler{
- Models: models,
- }
-
- data.DeleteAllWalletFixtures(t, ctx, dbConnectionPool)
+ // Fixture setup
+ wallet := data.ClearAndCreateWalletFixtures(t, ctx, dbConnectionPool)[0]
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "XLM", "")
- t.Run("returns BadRequest when payload is invalid", func(t *testing.T) {
- rr := httptest.NewRecorder()
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/wallets", strings.NewReader(`invalid`))
- require.NoError(t, err)
-
- http.HandlerFunc(handler.PostWallets).ServeHTTP(rr, req)
-
- resp := rr.Result()
-
- respBody, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
- assert.JSONEq(t, `{"error": "The request was invalid in some way."}`, string(respBody))
-
- rr = httptest.NewRecorder()
- req, err = http.NewRequestWithContext(ctx, http.MethodPost, "/wallets", strings.NewReader(`{}`))
- require.NoError(t, err)
-
- http.HandlerFunc(handler.PostWallets).ServeHTTP(rr, req)
-
- resp = rr.Result()
-
- respBody, err = io.ReadAll(resp.Body)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- expected := `
- {
+ // Define test cases
+ testCases := []struct {
+ name string
+ payload string
+ expectedStatus int
+ expectedBody string
+ }{
+ {
+ name: "🔴-400-BadRequest when payload is invalid",
+ payload: `invalid`,
+ expectedStatus: http.StatusBadRequest,
+ expectedBody: `{"error": "The request was invalid in some way."}`,
+ },
+ {
+ name: "🔴-400-BadRequest when payload is missing required fields",
+ payload: `{}`,
+ expectedStatus: http.StatusBadRequest,
+ expectedBody: `{
"error": "invalid request body",
"extras": {
"name": "name is required",
@@ -184,215 +214,134 @@ func Test_WalletsHandlerPostWallets(t *testing.T) {
"sep_10_client_domain": "sep_10_client_domain is required",
"assets_ids": "provide at least one asset ID"
}
- }
- `
- assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
- assert.JSONEq(t, expected, string(respBody))
-
- payload := `
- {
+ }`,
+ },
+ {
+ name: "🔴-400-BadRequest when assets_ids is missing",
+ payload: `{
"name": "New Wallet",
"homepage": "https://newwallet.com",
"deep_link_schema": "newwallet://sdp",
"sep_10_client_domain": "https://newwallet.com"
- }
- `
- rr = httptest.NewRecorder()
- req, err = http.NewRequestWithContext(ctx, http.MethodPost, "/wallets", strings.NewReader(payload))
- require.NoError(t, err)
-
- http.HandlerFunc(handler.PostWallets).ServeHTTP(rr, req)
-
- resp = rr.Result()
-
- respBody, err = io.ReadAll(resp.Body)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- expected = `
- {
+ }`,
+ expectedStatus: http.StatusBadRequest,
+ expectedBody: `{
"error": "invalid request body",
"extras": {
"assets_ids": "provide at least one asset ID"
}
- }
- `
- assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
- assert.JSONEq(t, expected, string(respBody))
- })
-
- t.Run("returns BadRequest when the URLs are invalids", func(t *testing.T) {
- payload := fmt.Sprintf(`
- {
+ }`,
+ },
+ {
+ name: "🔴-400-BadRequest when URLs are invalid",
+ payload: fmt.Sprintf(`{
"name": "New Wallet",
"homepage": "newwallet.com",
"deep_link_schema": "deeplink/sdp",
"sep_10_client_domain": "https://newwallet.com",
"assets_ids": [%q]
- }
- `, asset.ID)
- rr := httptest.NewRecorder()
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/wallets", strings.NewReader(payload))
- require.NoError(t, err)
-
- http.HandlerFunc(handler.PostWallets).ServeHTTP(rr, req)
-
- resp := rr.Result()
-
- respBody, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- expected := `
- {
+ }`, asset.ID),
+ expectedStatus: http.StatusBadRequest,
+ expectedBody: `{
"error": "invalid request body",
"extras": {
"deep_link_schema": "invalid deep link schema provided",
"homepage": "invalid homepage URL provided"
}
- }
- `
- assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
- assert.JSONEq(t, expected, string(respBody))
- })
-
- t.Run("returns Conflict when creating a duplicated wallet", func(t *testing.T) {
- wallet := data.ClearAndCreateWalletFixtures(t, ctx, dbConnectionPool)[0]
-
- // Duplicated Name
- payload := fmt.Sprintf(`
- {
+ }`,
+ },
+ {
+ name: "🔴-400-BadRequest when creating a wallet with an invalid asset ID",
+ payload: `{
+ "name": "New Wallet",
+ "homepage": "https://newwallet.com",
+ "deep_link_schema": "newwallet://sdp",
+ "sep_10_client_domain": "https://newwallet.com",
+ "assets_ids": ["invalid-asset-id"]
+ }`,
+ expectedStatus: http.StatusBadRequest,
+ expectedBody: `{"error": "invalid asset ID"}`,
+ },
+ {
+ name: "🔴-409-Conflict when creating a duplicated wallet (name)",
+ payload: fmt.Sprintf(`{
"name": %q,
- "homepage": %q,
- "deep_link_schema": %q,
- "sep_10_client_domain": %q,
+ "homepage": "https://newwallet.com",
+ "deep_link_schema": "newwallet://sdp",
+ "sep_10_client_domain": "https://newwallet.com",
"assets_ids": [%q]
- }
- `, wallet.Name, wallet.Homepage, wallet.DeepLinkSchema, wallet.SEP10ClientDomain, asset.ID)
- rr := httptest.NewRecorder()
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/wallets", strings.NewReader(payload))
- require.NoError(t, err)
-
- http.HandlerFunc(handler.PostWallets).ServeHTTP(rr, req)
-
- resp := rr.Result()
-
- respBody, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- assert.Equal(t, http.StatusConflict, resp.StatusCode)
- assert.JSONEq(t, `{"error": "a wallet with this name already exists"}`, string(respBody))
-
- // Duplicated Homepage
- payload = fmt.Sprintf(`
- {
+ }`, wallet.Name, asset.ID),
+ expectedStatus: http.StatusConflict,
+ expectedBody: `{"error": "a wallet with this name already exists"}`,
+ },
+ {
+ name: "🔴-409-Conflict when creating a duplicated wallet (homepage)",
+ payload: fmt.Sprintf(`{
"name": "New Wallet",
"homepage": %q,
- "deep_link_schema": %q,
- "sep_10_client_domain": %q,
+ "deep_link_schema": "newwallet://sdp",
+ "sep_10_client_domain": "https://newwallet.com",
"assets_ids": [%q]
- }
- `, wallet.Homepage, wallet.DeepLinkSchema, wallet.SEP10ClientDomain, asset.ID)
- rr = httptest.NewRecorder()
- req, err = http.NewRequestWithContext(ctx, http.MethodPost, "/wallets", strings.NewReader(payload))
- require.NoError(t, err)
-
- http.HandlerFunc(handler.PostWallets).ServeHTTP(rr, req)
-
- resp = rr.Result()
-
- respBody, err = io.ReadAll(resp.Body)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- assert.Equal(t, http.StatusConflict, resp.StatusCode)
- assert.JSONEq(t, `{"error": "a wallet with this homepage already exists"}`, string(respBody))
-
- // Duplicated Deep Link Schema
- payload = fmt.Sprintf(`
- {
+ }`, wallet.Homepage, asset.ID),
+ expectedStatus: http.StatusConflict,
+ expectedBody: `{"error": "a wallet with this homepage already exists"}`,
+ },
+ {
+ name: "🔴-409-Conflict when creating a duplicated wallet (deep_link_schema)",
+ payload: fmt.Sprintf(`{
"name": "New Wallet",
"homepage": "https://newwallet.com",
"deep_link_schema": %q,
- "sep_10_client_domain": %q,
+ "sep_10_client_domain": "https://newwallet.com",
"assets_ids": [%q]
- }
- `, wallet.DeepLinkSchema, wallet.SEP10ClientDomain, asset.ID)
- rr = httptest.NewRecorder()
- req, err = http.NewRequestWithContext(ctx, http.MethodPost, "/wallets", strings.NewReader(payload))
- require.NoError(t, err)
-
- http.HandlerFunc(handler.PostWallets).ServeHTTP(rr, req)
-
- resp = rr.Result()
-
- respBody, err = io.ReadAll(resp.Body)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- assert.Equal(t, http.StatusConflict, resp.StatusCode)
- assert.JSONEq(t, `{"error": "a wallet with this deep link schema already exists"}`, string(respBody))
-
- // Invalid asset ID
- payload = fmt.Sprintf(`
- {
- "name": "New Wallet",
- "homepage": "https://newwallet.com",
- "deep_link_schema": "newwallet://sdp",
- "sep_10_client_domain": %q,
- "assets_ids": ["asset-id"]
- }
- `, wallet.SEP10ClientDomain)
- rr = httptest.NewRecorder()
- req, err = http.NewRequestWithContext(ctx, http.MethodPost, "/wallets", strings.NewReader(payload))
- require.NoError(t, err)
-
- http.HandlerFunc(handler.PostWallets).ServeHTTP(rr, req)
-
- resp = rr.Result()
-
- respBody, err = io.ReadAll(resp.Body)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- assert.Equal(t, http.StatusConflict, resp.StatusCode)
- assert.JSONEq(t, `{"error": "invalid asset ID"}`, string(respBody))
- })
-
- t.Run("creates wallet successfully", func(t *testing.T) {
- data.DeleteAllWalletFixtures(t, ctx, dbConnectionPool)
-
- payload := fmt.Sprintf(`
- {
+ }`, wallet.DeepLinkSchema, asset.ID),
+ expectedStatus: http.StatusConflict,
+ expectedBody: `{"error": "a wallet with this deep link schema already exists"}`,
+ },
+ {
+ name: "🟢-successfully creates wallet",
+ payload: fmt.Sprintf(`{
"name": "New Wallet",
"homepage": "https://newwallet.com",
"deep_link_schema": "newwallet://deeplink/sdp",
"sep_10_client_domain": "https://newwallet.com",
"assets_ids": [%q]
- }
- `, asset.ID)
- rr := httptest.NewRecorder()
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/wallets", strings.NewReader(payload))
- require.NoError(t, err)
-
- http.HandlerFunc(handler.PostWallets).ServeHTTP(rr, req)
-
- resp := rr.Result()
-
- assert.Equal(t, http.StatusCreated, resp.StatusCode)
-
- wallet, err := models.Wallets.GetByWalletName(ctx, "New Wallet")
- require.NoError(t, err)
-
- walletAssets, err := models.Wallets.GetAssets(ctx, wallet.ID)
- require.NoError(t, err)
+ }`, asset.ID),
+ expectedStatus: http.StatusCreated,
+ expectedBody: "",
+ },
+ }
- assert.Equal(t, "https://newwallet.com", wallet.Homepage)
- assert.Equal(t, "newwallet://deeplink/sdp", wallet.DeepLinkSchema)
- assert.Equal(t, "newwallet.com", wallet.SEP10ClientDomain)
- assert.Len(t, walletAssets, 1)
- })
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ rr := httptest.NewRecorder()
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/wallets", strings.NewReader(tc.payload))
+ require.NoError(t, err)
+
+ http.HandlerFunc(handler.PostWallets).ServeHTTP(rr, req)
+
+ resp := rr.Result()
+ defer resp.Body.Close()
+ respBody, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ assert.Equal(t, tc.expectedStatus, resp.StatusCode)
+ if tc.expectedBody != "" {
+ assert.JSONEq(t, tc.expectedBody, string(respBody))
+ } else if tc.expectedStatus == http.StatusCreated {
+ wallet, err := models.Wallets.GetByWalletName(ctx, "New Wallet")
+ require.NoError(t, err)
+
+ walletAssets, err := models.Wallets.GetAssets(ctx, wallet.ID)
+ require.NoError(t, err)
+
+ assert.Equal(t, "https://newwallet.com", wallet.Homepage)
+ assert.Equal(t, "newwallet://deeplink/sdp", wallet.DeepLinkSchema)
+ assert.Equal(t, "newwallet.com", wallet.SEP10ClientDomain)
+ assert.Len(t, walletAssets, 1)
+ }
+ })
+ }
}
func Test_WalletsHandlerDeleteWallet(t *testing.T) {
diff --git a/internal/serve/middleware/middleware_test.go b/internal/serve/middleware/middleware_test.go
index 09856d6c9..2e593c64f 100644
--- a/internal/serve/middleware/middleware_test.go
+++ b/internal/serve/middleware/middleware_test.go
@@ -13,7 +13,6 @@ import (
"github.com/go-chi/chi/v5"
"github.com/sirupsen/logrus"
-
"github.com/stellar/go/support/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
diff --git a/internal/serve/publicfiles/js/receiver_registration.js b/internal/serve/publicfiles/js/receiver_registration.js
index 375e3bc4f..181b655ac 100644
--- a/internal/serve/publicfiles/js/receiver_registration.js
+++ b/internal/serve/publicfiles/js/receiver_registration.js
@@ -1,411 +1,442 @@
+// ------------------------------ START: ENUMS ------------------------------
+const ContactMethods = Object.freeze({
+ PHONE_NUMBER: "phone_number",
+ EMAIL: "email",
+});
+
+const CurrentSection = Object.freeze({
+ SELECT_OTP_METHOD: "selectOtpMethod", // SECTION 1
+ PHONE_NUMBER: "phoneNumber", // SECTION 2.1 (w/ phone_number)
+ EMAIL_ADDRESS: "emailAddress", // SECTION 2.2 (w/ email)
+ PASSCODE: "passcode", // SECTION 3
+});
+
+const VerificationField = Object.freeze({
+ DATE_OF_BIRTH: "DATE_OF_BIRTH",
+ YEAR_MONTH: "YEAR_MONTH",
+ NATIONAL_ID_NUMBER: "NATIONAL_ID_NUMBER",
+ PIN: "PIN",
+});
+// ------------------------------ END: ENUMS ------------------------------
+
+
+// ------------------------------ START: GLOBAL VARIABLES AND METHODS ------------------------------
+const reCAPTCHAWidgets = {};
+
+function resetReCAPTCHA() {
+ Object.values(reCAPTCHAWidgets).forEach(widgetId => grecaptcha.reset(widgetId));
+}
+
const WalletRegistration = {
jwtToken: "",
intlTelInput: null,
- phoneNumberErrorEl: null,
privacyPolicyLink: "",
-};
+ contactMethod: "",
+ currentSection: CurrentSection.SELECT_OTP_METHOD,
+ verificationField: "",
+
+ setSection(section) {
+ this.currentSection = section;
+ switch (section) {
+ case CurrentSection.PHONE_NUMBER:
+ this.contactMethod = ContactMethods.PHONE_NUMBER;
+ break;
+ case CurrentSection.EMAIL_ADDRESS:
+ this.contactMethod = ContactMethods.EMAIL;
+ break;
+ }
-function getJwtToken() {
- const tokenEl = document.querySelector("[data-jwt-token]");
+ Object.values(CurrentSection).forEach((s) => {
+ const sectionEl = document.querySelector(`[data-section='${s}']`);
+ if (sectionEl) sectionEl.style.display = s === section ? "flex" : "none";
+ });
+ },
+
+ errorNotificationEl() {
+ return document.querySelector("[data-section-error]");
+ },
+
+ successNotificationEl() {
+ return document.querySelector("[data-section-success]");
+ },
+
+ toggleErrorNotification(title, message, isVisible) {
+ const errorNotificationEl = this.errorNotificationEl();
+ toggleNotification("error", { parentEl: errorNotificationEl, title, message, isVisible });
+ },
+
+ toggleSuccessNotification(title, message, isVisible) {
+ const successNotificationEl = this.successNotificationEl();
+ toggleNotification("success", { parentEl: successNotificationEl, title, message, isVisible });
+ },
+
+ getRecaptchaToken() {
+ const tokenSelectorMap = {
+ [CurrentSection.EMAIL_ADDRESS]: "#g-recaptcha-response",
+ [CurrentSection.PHONE_NUMBER]: "#g-recaptcha-response-1",
+ [CurrentSection.PASSCODE]: "#g-recaptcha-response-2",
+ };
+
+ const recaptchaEl = document.querySelector(tokenSelectorMap[this.currentSection]);
+ return recaptchaEl?.value || "";
+ },
+
+ getSectionEl() {
+ return document.querySelector(`[data-section='${this.currentSection}']`);
+ },
+
+ toggleButtonsEnabled(isEnabled) {
+ const sectionEl = this.getSectionEl();
+ const buttonEls = sectionEl?.querySelectorAll("[data-button]");
+ if (!buttonEls) return;
+ const t = window.setTimeout(() => {
+ buttonEls.forEach((b) => {
+ b.disabled = !isEnabled;
+ });
- if (tokenEl) {
- return tokenEl.innerHTML;
- }
-}
+ clearTimeout(t);
+ }, isEnabled ? 1000 : 0);
+ },
-function getPrivacyPolicyLink() {
- const linkEl = document.querySelector("[data-privacy-policy-link]");
+ getContactValue() {
+ switch (this.contactMethod) {
+ case ContactMethods.PHONE_NUMBER:
+ return this.intlTelInput.getNumber().trim();
+ case ContactMethods.EMAIL:
+ return document.querySelector("#email_address").value.trim();
+ }
+ },
- if (linkEl) {
- return linkEl.innerHTML;
- }
-}
+ validateContactValue() {
+ const contactValue = this.getContactValue();
+ if (!contactValue) {
+ this.toggleErrorNotification("Error", "Contact information is required", true);
+ return -1;
+ }
-document.addEventListener("DOMContentLoaded", function () {
- const footer = document.getElementById("WalletRegistration__PrivacyPolicy");
+ if (this.contactMethod === ContactMethods.PHONE_NUMBER) {
+ if (!this.intlTelInput.isPossibleNumber()) {
+ this.toggleErrorNotification("Error", "Entered phone number is not valid", true);
+ return -1;
+ }
+ } else if (this.contactMethod === ContactMethods.EMAIL) {
+ const isValidEmail = (email) => /^[^\s@]+@[^\s@]+\.[^\s@]+$/.test(email);
+ if (!isValidEmail(contactValue)) {
+ this.toggleErrorNotification("Error", "Entered email is not valid", true);
+ return -1;
+ }
+ }
- if (WalletRegistration.privacyPolicyLink == "") {
- footer.style = "display: none"
+ this.toggleErrorNotification("", "", false);
+ return 0;
}
-});
+};
+// ------------------------------ END: GLOBAL VARIABLES AND METHODS ------------------------------
-function toggleNotification(type, { parentEl, title, message, isVisible }) {
- const titleEl = parentEl.querySelector(`[data-section-${type}-title]`);
- const messageEl = parentEl.querySelector(`[data-section-${type}-message`);
- if (titleEl && messageEl) {
- if (isVisible) {
- parentEl.style.display = "flex";
- titleEl.innerHTML = title;
- messageEl.innerHTML = message;
+// ------------------------------ START: INITIALIZATION ------------------------------
+window.onload = () => {
+ WalletRegistration.jwtToken = document.querySelector("#jwt-token").dataset.jwtToken
+ WalletRegistration.privacyPolicyLink = document.querySelector("[data-privacy-policy-link]")?.innerHTML || "";
+ WalletRegistration.intlTelInput = phoneNumberInit();
+
+ ensureRecaptchaLoaded();
+ grecaptcha.ready(function(){
+ // Render reCAPTCHA instances and store their widget IDs
+ const siteKey = document.querySelector("#recaptcha-site-key").dataset.sitekey;
+ reCAPTCHAWidgets.email = grecaptcha.render('g-recaptcha-email', { sitekey: siteKey });
+ reCAPTCHAWidgets.phone = grecaptcha.render('g-recaptcha-phone', { sitekey: siteKey });
+ reCAPTCHAWidgets.passcode = grecaptcha.render('g-recaptcha-passcode', { sitekey: siteKey });
+ });
+};
+
+// https://developers.google.com/recaptcha/docs/loading#loading_recaptcha_asynchronously
+function ensureRecaptchaLoaded() {
+ // How this code snippet works:
+ // This logic overwrites the default behavior of `grecaptcha.ready()` to
+ // ensure that it can be safely called at any time. When `grecaptcha.ready()`
+ // is called before reCAPTCHA is loaded, the callback function that is passed
+ // by `grecaptcha.ready()` is enqueued for execution after reCAPTCHA is
+ // loaded.
+ if(typeof grecaptcha === 'undefined') {
+ grecaptcha = {};
+ }
+ grecaptcha.ready = function(cb){
+ if(typeof grecaptcha === 'undefined') {
+ // window.__grecaptcha_cfg is a global variable that stores reCAPTCHA's
+ // configuration. By default, any functions listed in its 'fns' property
+ // are automatically executed when reCAPTCHA loads.
+ const c = '___grecaptcha_cfg';
+ window[c] = window[c] || {};
+ (window[c]['fns'] = window[c]['fns']||[]).push(cb);
} else {
- parentEl.style.display = "none";
- titleEl.innerHTML = "";
- messageEl.innerHTML = "";
+ cb();
}
}
}
-function toggleErrorNotification(parentEl, title, message, isVisible) {
- toggleNotification("error", { parentEl, title, message, isVisible });
-}
-
-function toggleSuccessNotification(parentEl, title, message, isVisible) {
- toggleNotification("success", { parentEl, title, message, isVisible });
-}
+// Phone number input (ref: https://github.com/jackocnr/intl-tel-input)
+function phoneNumberInit() {
+ const phoneNumberInput = document.querySelector("#phone_number");
-async function sendSms(phoneNumber, reCAPTCHAToken, onSuccess, onError) {
- if (phoneNumber && reCAPTCHAToken) {
- try {
- const response = await fetch("/wallet-registration/otp", {
- method: "POST",
- headers: {
- "Content-Type": "application/json",
- Authorization: `Bearer ${WalletRegistration.jwtToken}`,
- },
- body: JSON.stringify({
- phone_number: phoneNumber,
- recaptcha_token: reCAPTCHAToken,
- }),
- });
+ const intlTelInput = window.intlTelInput(phoneNumberInput, {
+ utilsScript: "/static/js/intl-tel-input-v18.2.1-utils.min.js",
+ separateDialCode: true,
+ preferredCountries: [],
+ // Excluding Cuba, Iran, North Korea, and Syria
+ excludeCountries: ["cu", "ir", "kp", "sy"],
+ // Setting default country based on user's IP address
+ initialCountry: "auto",
+ geoIpLookup: (callback) => {
+ fetch("https://ipapi.co/json")
+ .then((res) => res.json())
+ .then((data) => callback(data.country_code))
+ .catch(() => callback(""));
+ },
+ });
- const data = await response.json();
- if (!response.ok) {
- throw new Error(data.error || "Something went wrong, please try again later.");
+ // Clear phone number error message
+ const errorNotificationEl = WalletRegistration.errorNotificationEl();
+ ["change", "keyup"].forEach((event) => {
+ phoneNumberInput.addEventListener(event, () => {
+ if (errorNotificationEl.style.display !== "none") {
+ WalletRegistration.toggleErrorNotification("", "", false);
+ WalletRegistration.toggleButtonsEnabled(true);
}
-
- onSuccess(data.verification_field);;
- } catch (error) {
- onError(error);
- }
- }
+ });
+ });
+
+ return intlTelInput;
}
-function disableButtons(buttons) {
- buttons.forEach((b) => {
- b.disabled = true;
+document.addEventListener("DOMContentLoaded", function () {
+ // Hide Privacy Policy Link if not provided
+ const footer = document.getElementById("WalletRegistration__PrivacyPolicy");
+ if (!WalletRegistration.privacyPolicyLink) {
+ footer.style.display = "none";
+ }
+
+ // SECTION 1: Setup OTP Method Form
+ const otpMethodForm = document.getElementById("selectOtpMethodForm");
+ otpMethodForm?.addEventListener("change", () => {
+ WalletRegistration.toggleErrorNotification("", "", false);
+ });
+ otpMethodForm?.addEventListener("submit", (event) => {
+ event.preventDefault();
+ handleOtpSelected();
});
-}
-function enableButtons(buttons) {
- const t = window.setTimeout(() => {
- buttons.forEach((b) => {
- b.disabled = false;
+ // SECTION 2: Setup Email and Phone Number Forms
+ ["submitEmailForm", "submitPhoneNumberForm"].forEach((formId) => {
+ document.getElementById(formId)?.addEventListener("submit", (event) => {
+ event.preventDefault();
+ handleContactInfoSubmitted();
});
+ });
- clearTimeout(t);
- }, 1000);
-}
-
-document.addEventListener("DOMContentLoaded", function () {
- const form = document.getElementById("submitPhoneNumberForm");
+ // SECTION 3: Setup OTP Form
+ document.getElementById("submitVerificationForm")?.addEventListener("submit", (event) => {
+ event.preventDefault();
+ handleVerificationInfoSubmitted();
+ });
- form.addEventListener("submit", function (event) {
- submitPhoneNumber(event);
+ // SECTION 3: Setup Resend OTP Button
+ document.getElementById("resendOtpButton")?.addEventListener("click", (event) => {
+ event.preventDefault();
+ handleResendOtpClicked();
});
});
-async function submitPhoneNumber(event) {
- event.preventDefault();
- const phoneNumberEl = document.querySelector("#phone_number");
- const phoneNumberSectionEl = document.querySelector(
- "[data-section='phoneNumber']"
- );
- const passcodeSectionEl = document.querySelector("[data-section='passcode']");
- const errorNotificationEl = WalletRegistration.phoneNumberErrorEl;
- const reCAPTCHATokenEl = phoneNumberSectionEl.querySelector(
- "#g-recaptcha-response"
- );
- const buttonEls = phoneNumberSectionEl.querySelectorAll("[data-button]");
- const verificationFieldTitle = document.querySelector("label[for='verification']");
- const verificationFieldInput = document.querySelector("#verification");
-
- if (!reCAPTCHATokenEl || !reCAPTCHATokenEl.value) {
- toggleErrorNotification(
- errorNotificationEl,
- "Error",
- "reCAPTCHA is required",
- true
- );
+
+// ------------------------------ START: SECTION 1 ------------------------------
+function handleOtpSelected() {
+ const selectedMethod = document.querySelector('input[name="otp_method"]:checked')?.value;
+ if (!selectedMethod) {
+ WalletRegistration.toggleErrorNotification("Error", "Please select a contact method to receive your OTP", true);
return;
}
+ WalletRegistration.setSection(selectedMethod);
+}
+// ------------------------------ END: SECTION 1 ------------------------------
- toggleErrorNotification(errorNotificationEl, "", "", false);
-
- if (
- WalletRegistration.intlTelInput &&
- reCAPTCHATokenEl &&
- phoneNumberSectionEl &&
- passcodeSectionEl &&
- errorNotificationEl
- ) {
- disableButtons(buttonEls);
- const phoneNumber = WalletRegistration.intlTelInput.getNumber();
- const reCAPTCHAToken = reCAPTCHATokenEl.value;
-
- if (
- phoneNumberEl.value.trim() &&
- !WalletRegistration.intlTelInput.isPossibleNumber()
- ) {
- toggleErrorNotification(
- errorNotificationEl,
- "Error",
- "Entered phone number is not valid",
- true
- );
- return;
- }
- function showNextPage(verificationField) {
- verificationFieldInput.type = "text";
- if(verificationField === "DATE_OF_BIRTH") {
- verificationFieldTitle.textContent = "Date of birth";
- verificationFieldInput.name = "date_of_birth";
- verificationFieldInput.type = "date";
- }
- else if(verificationField === "YEAR_MONTH") {
- verificationFieldTitle.textContent = "Date of birth (Year/Month)";
- verificationFieldInput.name = "year_month";
- verificationFieldInput.type = "month";
- }
- else if(verificationField === "NATIONAL_ID_NUMBER") {
- verificationFieldTitle.textContent = "National ID number";
- verificationFieldInput.name = "national_id_number";
- }
- else if(verificationField === "PIN") {
- verificationFieldTitle.textContent = "Pin";
- verificationFieldInput.name = "pin";
- }
+// ------------------------------ START: SECTION 2 ------------------------------
+async function handleContactInfoSubmitted() {
+ if (![CurrentSection.PHONE_NUMBER, CurrentSection.EMAIL_ADDRESS].includes(WalletRegistration.currentSection)) {
+ alert("Invalid section to submit contact information: " + WalletRegistration.currentSection);
+ return;
+ }
- phoneNumberSectionEl.style.display = "none";
- reCAPTCHATokenEl.style.display = "none";
- passcodeSectionEl.style.display = "flex";
- enableButtons(buttonEls);
- }
+ const reCAPTCHAToken = WalletRegistration.getRecaptchaToken();
+ if (!reCAPTCHAToken) {
+ WalletRegistration.toggleErrorNotification("Error", "reCAPTCHA is required", true);
+ return;
+ }
- function showErrorMessage(error) {
- toggleErrorNotification(errorNotificationEl, "Error", error, true);
- enableButtons(buttonEls);
+ WalletRegistration.toggleErrorNotification("", "", false);
+ WalletRegistration.toggleButtonsEnabled(false);
+ if (WalletRegistration.validateContactValue() === -1) return;
+
+ function showNextPage(verificationField) {
+ const verificationFieldTitle = document.querySelector("label[for='verification']");
+ const verificationFieldInput = document.querySelector("#verification");
+ WalletRegistration.verificationField = verificationField;
+
+ const inputFeldConfigMap = {
+ [VerificationField.DATE_OF_BIRTH]: { name: "date_of_birth", type: "date", label: "Date of birth" },
+ [VerificationField.YEAR_MONTH]: { name: "year_month", type: "month", label: "Date of birth (Year/Month)" },
+ [VerificationField.NATIONAL_ID_NUMBER]: { name: "national_id_number", type: "text", label: "National ID number" },
+ [VerificationField.PIN]: { name: "pin", type: "text", label: "Pin" },
+ };
+
+ const inputFieldConfig = inputFeldConfigMap[verificationField];
+ if (inputFieldConfig) {
+ verificationFieldTitle.textContent = inputFieldConfig.label;
+ verificationFieldInput.name = inputFieldConfig.name;
+ verificationFieldInput.type = inputFieldConfig.type;
}
- sendSms(phoneNumber, reCAPTCHAToken, showNextPage, showErrorMessage);
+ WalletRegistration.setSection(CurrentSection.PASSCODE);
+ WalletRegistration.toggleButtonsEnabled(true);
+ }
+
+ function showErrorMessage(error) {
+ WalletRegistration.toggleErrorNotification("Error", error, true);
+ WalletRegistration.toggleButtonsEnabled(true);
}
+
+ sendOtp(showNextPage, showErrorMessage);
}
+// ------------------------------ END: SECTION 2 ------------------------------
-document.addEventListener("DOMContentLoaded", function () {
- const form = document.getElementById("submitOtpForm");
- form.addEventListener("submit", function (event) {
- submitOtp(event);
- });
-});
+// ------------------------------ START: SECTION 3 ------------------------------
+async function handleVerificationInfoSubmitted() {
+ const reCAPTCHAToken = WalletRegistration.getRecaptchaToken();
+ if (!reCAPTCHAToken) {
+ WalletRegistration.toggleErrorNotification("Error", "reCAPTCHA is required", true);
+ return;
+ }
-async function submitOtp(event) {
- event.preventDefault();
-
- const passcodeSectionEl = document.querySelector("[data-section='passcode']");
- const errorNotificationEl = document.querySelector(
- "[data-section-error='passcode']"
- );
- const successNotificationEl = document.querySelector(
- "[data-section-success='passcode']"
- );
- const otpEl = document.getElementById("otp");
- const verificationEl = document.getElementById("verification");
- const verificationField = verificationEl.getAttribute("name");
-
- const buttonEls = passcodeSectionEl.querySelectorAll("[data-button]");
-
- const reCAPTCHATokenEl = passcodeSectionEl.querySelector(
- "#g-recaptcha-response-1"
- );
- if (!reCAPTCHATokenEl || !reCAPTCHATokenEl.value) {
- toggleErrorNotification(
- errorNotificationEl,
- "Error",
- "reCAPTCHA is required",
- true
- );
+ const contactMethod = WalletRegistration.contactMethod;
+ const contactValue = WalletRegistration.getContactValue();
+ const otp = document.getElementById("otp").value;
+ const verificationFieldValue = document.getElementById("verification").value;
+ if (!contactMethod || !contactValue || !otp || !verificationFieldValue) {
+ const errMessage = `Missing one of the required fields: ${{ contactMethod, contactValue, otp, verificationFieldValue }}`;
+ WalletRegistration.toggleErrorNotification("Error", errMessage, true);
return;
}
- if (
- WalletRegistration.intlTelInput &&
- otpEl &&
- verificationEl &&
- passcodeSectionEl &&
- errorNotificationEl
- ) {
- toggleErrorNotification(errorNotificationEl, "", "", false);
- toggleSuccessNotification(successNotificationEl, "", "", false);
-
- const phoneNumber = WalletRegistration.intlTelInput.getNumber();
- const otp = otpEl.value;
- const verification = verificationEl.value;
-
- if (phoneNumber && otp && verification) {
- try {
- disableButtons(buttonEls);
-
- const response = await fetch("/wallet-registration/verification", {
- method: "POST",
- headers: {
- "Content-Type": "application/json",
- Authorization: `Bearer ${WalletRegistration.jwtToken}`,
- },
- body: JSON.stringify({
- phone_number: phoneNumber,
- otp: otp,
- verification: verification,
- verification_type: verificationField,
- recaptcha_token: reCAPTCHATokenEl.value,
- }),
- });
-
- if ([200, 201].includes(response.status)) {
- await response.json();
-
- const t = window.setTimeout(() => {
- location.reload();
- clearTimeout(t);
- }, 2000);
- } else if (response.status === 400) {
- const data = await response.json();
- const errorMessage = data.error || "Something went wrong, please try again later.";
- throw new Error(errorMessage);
- } else {
- throw new Error("Something went wrong, please try again later.");
- }
- } catch (error) {
- enableButtons(buttonEls);
- toggleErrorNotification(errorNotificationEl, "Error", error, true);
- grecaptcha.reset(1);
- }
+ WalletRegistration.toggleErrorNotification("", "", false);
+ WalletRegistration.toggleSuccessNotification("", "", false);
+
+ try {
+ WalletRegistration.toggleButtonsEnabled(false);
+
+ const response = await fetch("/wallet-registration/verification", {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/json",
+ Authorization: `Bearer ${WalletRegistration.jwtToken}`,
+ },
+ body: JSON.stringify({
+ [contactMethod]: contactValue,
+ otp: otp,
+ recaptcha_token: reCAPTCHAToken,
+ verification_field: WalletRegistration.verificationField,
+ verification: verificationFieldValue,
+ }),
+ });
+
+ if (Math.floor(response.status / 100) === 2) {
+ await response.json();
+ setTimeout(() => {
+ location.reload();
+ }, 2000);
+ } else if (response.status === 400) {
+ const data = await response.json();
+ const errorMessage = data.error || "Something went wrong with your request, please try again later.";
+ throw new Error(errorMessage);
+ } else {
+ throw new Error(`Something went wrong, please try again later (status code: ${response.status}).`);
}
+ } catch (error) {
+ WalletRegistration.toggleButtonsEnabled(true);
+ WalletRegistration.toggleErrorNotification("Error", error, true);
+ resetReCAPTCHA();
}
}
-document.addEventListener("DOMContentLoaded", function () {
- const button = document.getElementById("resendSmsButton");
-
- button.addEventListener("click", function (event) {
- resendSms(event);
- });
-});
+async function handleResendOtpClicked() {
+ const reCAPTCHAToken = WalletRegistration.getRecaptchaToken();
+ if (!reCAPTCHAToken) {
+ WalletRegistration.toggleErrorNotification("Error", "reCAPTCHA is required", true);
+ return;
+ }
-async function resendSms() {
- const passcodeSectionEl = document.querySelector("[data-section='passcode']");
- const errorNotificationEl = document.querySelector(
- "[data-section-error='passcode']"
- );
- const successNotificationEl = document.querySelector(
- "[data-section-success='passcode']"
- );
- const buttonEls = passcodeSectionEl.querySelectorAll("[data-button]");
- const reCAPTCHATokenEl = passcodeSectionEl.querySelector(
- "#g-recaptcha-response-1"
- );
-
- if (!reCAPTCHATokenEl || !reCAPTCHATokenEl.value) {
- toggleErrorNotification(
- errorNotificationEl,
- "Error",
- "reCAPTCHA is required",
- true
- );
+ const contactValue = WalletRegistration.getContactValue();
+ if (!contactValue) {
+ WalletRegistration.toggleErrorNotification("Error", "Contact information is required", true);
return;
}
- if (
- (passcodeSectionEl,
- errorNotificationEl,
- WalletRegistration.intlTelInput,
- reCAPTCHATokenEl)
- ) {
- disableButtons(buttonEls);
- toggleErrorNotification(errorNotificationEl, "", "", false);
- toggleSuccessNotification(successNotificationEl, "", "", false);
-
- const phoneNumber = WalletRegistration.intlTelInput.getNumber();
- const reCAPTCHAToken = reCAPTCHATokenEl.value;
-
- function showErrorMessage(error) {
- toggleErrorNotification(errorNotificationEl, "Error", error, true);
- enableButtons(buttonEls);
- }
+ WalletRegistration.toggleButtonsEnabled(false);
+ WalletRegistration.toggleErrorNotification("", "", false);
+ WalletRegistration.toggleSuccessNotification("", "", false);
- function showSuccessMessage() {
- toggleSuccessNotification(
- successNotificationEl,
- "New SMS sent",
- "You will receive a new one-time passcode",
- true
- );
- enableButtons(buttonEls);
- }
+ function showErrorMessage(error) {
+ WalletRegistration.toggleErrorNotification("Error", error, true);
+ WalletRegistration.toggleButtonsEnabled(true);
+ }
- sendSms(phoneNumber, reCAPTCHAToken, showSuccessMessage, showErrorMessage);
- grecaptcha.reset(1);
+ function showSuccessMessage() {
+ WalletRegistration.toggleSuccessNotification("New OTP sent", "You will receive a new one-time passcode", true);
+ WalletRegistration.toggleButtonsEnabled(true);
}
+
+ sendOtp(showSuccessMessage, showErrorMessage);
+ resetReCAPTCHA();
}
+// ------------------------------ END: SECTION 3 ------------------------------
+
+
+// ------------------------------ START: UTILITY FUNCTIONS ------------------------------
+function toggleNotification(type, { parentEl, title, message, isVisible }) {
+ const titleEl = parentEl.querySelector(`[data-section-${type}-title]`);
+ const messageEl = parentEl.querySelector(`[data-section-${type}-message`);
-function resetNumberInputError(buttonEls) {
- if (
- WalletRegistration.phoneNumberErrorEl &&
- WalletRegistration.phoneNumberErrorEl.style.display !== "none"
- ) {
- toggleErrorNotification(
- WalletRegistration.phoneNumberErrorEl,
- "",
- "",
- false
- );
- enableButtons(buttonEls);
+ if (titleEl && messageEl) {
+ parentEl.style.display = isVisible ? "flex" : "none";
+ titleEl.innerHTML = isVisible ? title : "";
+ messageEl.innerHTML = isVisible ? message : "";
}
}
-// Phone number input
-// https://github.com/jackocnr/intl-tel-input
-function phoneNumberInit() {
- const phoneNumberInput = document.querySelector("#phone_number");
- const phoneNumberSectionEl = document.querySelector(
- "[data-section='phoneNumber']"
- );
- const buttonEls = phoneNumberSectionEl.querySelectorAll("[data-button]");
-
- const intlTelInput = window.intlTelInput(phoneNumberInput, {
- utilsScript: "/static/js/intl-tel-input-v18.2.1-utils.min.js",
- separateDialCode: true,
- preferredCountries: [],
- // Excluding Cuba, Iran, North Korea, and Syria
- excludeCountries: ["cu", "ir", "kp", "sy"],
- // Setting default country based on user's IP address
- initialCountry: "auto",
- geoIpLookup: (callback) => {
- fetch("https://ipapi.co/json")
- .then((res) => res.json())
- .then((data) => callback(data.country_code))
- .catch(() => callback(""));
- },
- });
+async function sendOtp(onSuccess, onError) {
+ const reqPayload = {
+ [WalletRegistration.contactMethod]: WalletRegistration.getContactValue(),
+ recaptcha_token: WalletRegistration.getRecaptchaToken(),
+ };
+
+ try {
+ const response = await fetch("/wallet-registration/otp", {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/json",
+ Authorization: `Bearer ${WalletRegistration.jwtToken}`,
+ },
+ body: JSON.stringify(reqPayload),
+ });
- // Clear phone number error message
- phoneNumberInput.addEventListener("change", () =>
- resetNumberInputError(buttonEls)
- );
- phoneNumberInput.addEventListener("keyup", () =>
- resetNumberInputError(buttonEls)
- );
+ const data = await response.json();
+ if (!response.ok) {
+ throw new Error(data.error || "Something went wrong, please try again later.");
+ }
- return intlTelInput;
+ onSuccess(data.verification_field);
+ } catch (error) {
+ onError(error);
+ }
}
-
-// Init
-window.onload = async () => {
- WalletRegistration.jwtToken = getJwtToken();
- WalletRegistration.intlTelInput = phoneNumberInit();
- WalletRegistration.phoneNumberErrorEl = document.querySelector(
- "[data-section-error='phoneNumber']"
- );
- WalletRegistration.privacyPolicyLink = getPrivacyPolicyLink();
-};
+// ------------------------------ END: UTILITY FUNCTIONS ------------------------------
diff --git a/internal/serve/serve.go b/internal/serve/serve.go
index 77296800c..6f37cb534 100644
--- a/internal/serve/serve.go
+++ b/internal/serve/serve.go
@@ -64,7 +64,7 @@ type ServeOptions struct {
CorsAllowedOrigins []string
authManager auth.AuthManager
EmailMessengerClient message.MessengerClient
- SMSMessengerClient message.MessengerClient
+ MessageDispatcher message.MessageDispatcherInterface
SEP24JWTSecret string
sep24JWTManager *anchorplatform.JWTManager
BaseURL string
@@ -89,7 +89,7 @@ type ServeOptions struct {
DistributionAccountService services.DistributionAccountServiceInterface
DistAccEncryptionPassphrase string
EventProducer events.Producer
- MaxInvitationSMSResendAttempts int
+ MaxInvitationResendAttempts int
SingleTenantMode bool
CircleService circle.ServiceInterface
}
@@ -210,7 +210,6 @@ func handleHTTP(o ServeOptions) *chi.Mux {
httprate.WithKeyFuncs(httprate.KeyByIP, httprate.KeyByEndpoint),
))
mux.Use(chimiddleware.RequestID)
- mux.Use(chimiddleware.RealIP)
mux.Use(middleware.ResolveTenantFromRequestMiddleware(o.tenantManager, o.SingleTenantMode))
mux.Use(middleware.LoggingMiddleware)
mux.Use(middleware.RecoverHandler)
@@ -325,9 +324,9 @@ func handleHTTP(o ServeOptions) *chi.Mux {
Patch("/wallets/{receiver_wallet_id}", receiverWalletHandler.RetryInvitation)
})
- r.With(middleware.AnyRoleMiddleware(authManager, data.GetAllRoles()...)).Route("/countries", func(r chi.Router) {
- r.Get("/", httphandler.CountriesHandler{Models: o.Models}.GetCountries)
- })
+ r.
+ With(middleware.AnyRoleMiddleware(authManager, data.GetAllRoles()...)).
+ Get("/registration-contact-types", httphandler.RegistrationContactTypesHandler{}.Get)
r.Route("/assets", func(r chi.Router) {
assetsHandler := httphandler.AssetsHandler{
@@ -347,7 +346,10 @@ func handleHTTP(o ServeOptions) *chi.Mux {
})
r.With(middleware.AnyRoleMiddleware(authManager, data.GetAllRoles()...)).Route("/wallets", func(r chi.Router) {
- walletsHandler := httphandler.WalletsHandler{Models: o.Models}
+ walletsHandler := httphandler.WalletsHandler{
+ Models: o.Models,
+ NetworkType: o.NetworkType,
+ }
r.Get("/", walletsHandler.GetWallets)
r.With(middleware.AnyRoleMiddleware(authManager, data.DeveloperUserRole)).
Post("/", walletsHandler.PostWallets)
@@ -365,6 +367,7 @@ func handleHTTP(o ServeOptions) *chi.Mux {
DistributionAccountResolver: o.SubmitterEngine.DistributionAccountResolver,
PasswordValidator: o.PasswordValidator,
PublicFilesFS: publicfiles.PublicFiles,
+ NetworkType: o.NetworkType,
}
r.Route("/profile", func(r chi.Router) {
r.With(middleware.AnyRoleMiddleware(authManager, data.GetAllRoles()...)).
@@ -396,6 +399,7 @@ func handleHTTP(o ServeOptions) *chi.Mux {
EncryptionPassphrase: o.DistAccEncryptionPassphrase,
CircleClientConfigModel: circle.NewClientConfigModel(o.MtnDBConnectionPool),
DistributionAccountResolver: o.SubmitterEngine.DistributionAccountResolver,
+ MonitorService: o.MonitorService,
}.Patch)
})
@@ -471,7 +475,7 @@ func handleHTTP(o ServeOptions) *chi.Mux {
sep24HeaderTokenAuthenticationMiddleware := anchorplatform.SEP24HeaderTokenAuthenticateMiddleware(o.sep24JWTManager, o.NetworkPassphrase, o.tenantManager, o.SingleTenantMode)
r.With(sep24HeaderTokenAuthenticationMiddleware).Post("/otp", httphandler.ReceiverSendOTPHandler{
Models: o.Models,
- SMSMessengerClient: o.SMSMessengerClient,
+ MessageDispatcher: o.MessageDispatcher,
ReCAPTCHAValidator: reCAPTCHAValidator,
}.ServeHTTP)
r.With(sep24HeaderTokenAuthenticationMiddleware).Post("/verification", httphandler.VerifyReceiverRegistrationHandler{
@@ -486,7 +490,7 @@ func handleHTTP(o ServeOptions) *chi.Mux {
})
// This will be used for test purposes and will only be available when IsPubnet is false:
- r.With(middleware.EnsureTenantMiddleware).Delete("/phone-number/{phone_number}", httphandler.DeletePhoneNumberHandler{
+ r.With(middleware.EnsureTenantMiddleware).Delete("/contact-info/{contact_info}", httphandler.DeleteContactInfoHandler{
Models: o.Models,
NetworkPassphrase: o.NetworkPassphrase,
}.ServeHTTP)
diff --git a/internal/serve/serve_test.go b/internal/serve/serve_test.go
index 98b19b220..84e60ff33 100644
--- a/internal/serve/serve_test.go
+++ b/internal/serve/serve_test.go
@@ -304,6 +304,12 @@ func getServeOptionsForTests(t *testing.T, dbConnectionPool db.DBConnectionPool)
messengerClientMock := message.MessengerClientMock{}
messengerClientMock.On("SendMessage", mock.Anything).Return(nil)
+ messageDispatcherMock := message.NewMockMessageDispatcher(t)
+ messageDispatcherMock.
+ On("SendMessage", mock.Anything, mock.Anything).
+ Return(nil).
+ Maybe()
+
crashTrackerClient, err := crashtracker.NewDryRunClient()
require.NoError(t, err)
@@ -351,7 +357,7 @@ func getServeOptionsForTests(t *testing.T, dbConnectionPool db.DBConnectionPool)
SEP24JWTSecret: "jwt_secret_1234567890",
AnchorPlatformOutgoingJWTSecret: "jwt_secret_1234567890",
AnchorPlatformBasePlatformURL: "https://test.com",
- SMSMessengerClient: &messengerClientMock,
+ MessageDispatcher: messageDispatcherMock,
Version: "x.y.z",
NetworkPassphrase: network.TestNetworkPassphrase,
SubmitterEngine: submitterEngine,
@@ -448,8 +454,8 @@ func Test_handleHTTP_authenticatedEndpoints(t *testing.T) {
{http.MethodPatch, "/receivers/1234"},
{http.MethodPatch, "/receivers/wallets/1234"},
{http.MethodGet, "/receivers/verification-types"},
- // Countries
- {http.MethodGet, "/countries"},
+ // Receiver Contact Types
+ {http.MethodGet, "/registration-contact-types"},
// Assets
{http.MethodGet, "/assets"},
{http.MethodPost, "/assets"},
diff --git a/internal/serve/validators/disbursement_instructions_validator.go b/internal/serve/validators/disbursement_instructions_validator.go
index 2c9e85b45..0d296e519 100644
--- a/internal/serve/validators/disbursement_instructions_validator.go
+++ b/internal/serve/validators/disbursement_instructions_validator.go
@@ -4,63 +4,88 @@ import (
"fmt"
"strings"
+ "github.com/stellar/go/strkey"
+
"github.com/stellar/stellar-disbursement-platform-backend/internal/data"
"github.com/stellar/stellar-disbursement-platform-backend/internal/utils"
)
type DisbursementInstructionsValidator struct {
- verificationField data.VerificationField
+ contactType data.RegistrationContactType
+ verificationField data.VerificationType
*Validator
}
-func NewDisbursementInstructionsValidator(verificationField data.VerificationField) *DisbursementInstructionsValidator {
+func NewDisbursementInstructionsValidator(contactType data.RegistrationContactType, verificationField data.VerificationType) *DisbursementInstructionsValidator {
return &DisbursementInstructionsValidator{
+ contactType: contactType,
verificationField: verificationField,
Validator: NewValidator(),
}
}
func (iv *DisbursementInstructionsValidator) ValidateInstruction(instruction *data.DisbursementInstruction, lineNumber int) {
- phone := strings.TrimSpace(instruction.Phone)
- id := strings.TrimSpace(instruction.ID)
- amount := strings.TrimSpace(instruction.Amount)
- verification := strings.TrimSpace(instruction.VerificationValue)
+ // 1. Validate required fields
+ iv.Check(instruction.ID != "", fmt.Sprintf("line %d - id", lineNumber), "id cannot be empty")
+ iv.CheckError(utils.ValidateAmount(instruction.Amount), fmt.Sprintf("line %d - amount", lineNumber), "invalid amount. Amount must be a positive number")
- // validate phone field
- iv.Check(phone != "", fmt.Sprintf("line %d - phone", lineNumber), "phone cannot be empty")
- if phone != "" {
- iv.CheckError(utils.ValidatePhoneNumber(phone), fmt.Sprintf("line %d - phone", lineNumber), "invalid phone format. Correct format: +380445555555")
+ // 2. Validate Contact fields
+ switch iv.contactType.ReceiverContactType {
+ case data.ReceiverContactTypeEmail:
+ iv.Check(instruction.Email != "", fmt.Sprintf("line %d - email", lineNumber), "email cannot be empty")
+ if instruction.Email != "" {
+ iv.CheckError(utils.ValidateEmail(instruction.Email), fmt.Sprintf("line %d - email", lineNumber), "invalid email format")
+ }
+ case data.ReceiverContactTypeSMS:
+ iv.Check(instruction.Phone != "", fmt.Sprintf("line %d - phone", lineNumber), "phone cannot be empty")
+ if instruction.Phone != "" {
+ iv.CheckError(utils.ValidatePhoneNumber(instruction.Phone), fmt.Sprintf("line %d - phone", lineNumber), "invalid phone format. Correct format: +380445555555")
+ }
}
- // validate id field
- iv.Check(id != "", fmt.Sprintf("line %d - id", lineNumber), "id cannot be empty")
-
- // validate amount field
- iv.CheckError(utils.ValidateAmount(amount), fmt.Sprintf("line %d - amount", lineNumber), "invalid amount. Amount must be a positive number")
-
- // validate verification field
- switch iv.verificationField {
- case data.VerificationFieldDateOfBirth:
- iv.CheckError(utils.ValidateDateOfBirthVerification(verification), fmt.Sprintf("line %d - date of birth", lineNumber), "")
- case data.VerificationFieldYearMonth:
- iv.CheckError(utils.ValidateYearMonthVerification(verification), fmt.Sprintf("line %d - year/month", lineNumber), "")
- case data.VerificationFieldPin:
- iv.CheckError(utils.ValidatePinVerification(verification), fmt.Sprintf("line %d - pin", lineNumber), "")
- case data.VerificationFieldNationalID:
- iv.CheckError(utils.ValidateNationalIDVerification(verification), fmt.Sprintf("line %d - national id", lineNumber), "")
+ // 3. Validate WalletAddress field
+ if iv.contactType.IncludesWalletAddress {
+ iv.Check(instruction.WalletAddress != "", fmt.Sprintf("line %d - wallet address", lineNumber), "wallet address cannot be empty")
+ if instruction.WalletAddress != "" {
+ iv.Check(strkey.IsValidEd25519PublicKey(instruction.WalletAddress), fmt.Sprintf("line %d - wallet address", lineNumber), "invalid wallet address. Must be a valid Stellar public key")
+ }
+ } else {
+ // 4. Validate verification field
+ verification := instruction.VerificationValue
+ switch iv.verificationField {
+ case data.VerificationTypeDateOfBirth:
+ iv.CheckError(utils.ValidateDateOfBirthVerification(verification), fmt.Sprintf("line %d - date of birth", lineNumber), "")
+ case data.VerificationTypeYearMonth:
+ iv.CheckError(utils.ValidateYearMonthVerification(verification), fmt.Sprintf("line %d - year/month", lineNumber), "")
+ case data.VerificationTypePin:
+ iv.CheckError(utils.ValidatePinVerification(verification), fmt.Sprintf("line %d - pin", lineNumber), "")
+ case data.VerificationTypeNationalID:
+ iv.CheckError(utils.ValidateNationalIDVerification(verification), fmt.Sprintf("line %d - national id", lineNumber), "")
+ }
}
}
func (iv *DisbursementInstructionsValidator) SanitizeInstruction(instruction *data.DisbursementInstruction) *data.DisbursementInstruction {
var sanitizedInstruction data.DisbursementInstruction
- sanitizedInstruction.Phone = strings.TrimSpace(instruction.Phone)
+ if instruction.Phone != "" {
+ sanitizedInstruction.Phone = strings.ToLower(strings.TrimSpace(instruction.Phone))
+ }
+
+ if instruction.Email != "" {
+ sanitizedInstruction.Email = strings.ToLower(strings.TrimSpace(instruction.Email))
+ }
+
+ if instruction.WalletAddress != "" {
+ sanitizedInstruction.WalletAddress = strings.ToUpper(strings.TrimSpace(instruction.WalletAddress))
+ }
+
+ if instruction.ExternalPaymentId != "" {
+ sanitizedInstruction.ExternalPaymentId = strings.TrimSpace(instruction.ExternalPaymentId)
+ }
+
sanitizedInstruction.ID = strings.TrimSpace(instruction.ID)
sanitizedInstruction.Amount = strings.TrimSpace(instruction.Amount)
sanitizedInstruction.VerificationValue = strings.TrimSpace(instruction.VerificationValue)
- if instruction.ExternalPaymentId != nil {
- externalPaymentId := strings.TrimSpace(*instruction.ExternalPaymentId)
- sanitizedInstruction.ExternalPaymentId = &externalPaymentId
- }
return &sanitizedInstruction
}
diff --git a/internal/serve/validators/disbursement_instructions_validator_test.go b/internal/serve/validators/disbursement_instructions_validator_test.go
index f577c79a8..f7c18abad 100644
--- a/internal/serve/validators/disbursement_instructions_validator_test.go
+++ b/internal/serve/validators/disbursement_instructions_validator_test.go
@@ -13,35 +13,53 @@ func Test_DisbursementInstructionsValidator_ValidateAndGetInstruction(t *testing
name string
instruction *data.DisbursementInstruction
lineNumber int
- verificationField data.VerificationField
+ contactType data.RegistrationContactType
+ verificationField data.VerificationType
hasErrors bool
expectedErrors map[string]interface{}
}{
{
- name: "error if phone number is empty",
+ name: "error if phone number is empty for Phone contact type",
instruction: &data.DisbursementInstruction{
ID: "123456789",
Amount: "100.5",
VerificationValue: "1990-01-01",
},
lineNumber: 2,
- verificationField: data.VerificationFieldDateOfBirth,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypeDateOfBirth,
hasErrors: true,
expectedErrors: map[string]interface{}{
"line 2 - phone": "phone cannot be empty",
},
},
{
- name: "error with all fields empty (phone, id, amount, date of birth)",
+ name: "error if email is empty for Email contact type",
+ instruction: &data.DisbursementInstruction{
+ ID: "123456789",
+ Amount: "100.5",
+ VerificationValue: "1990-01-01",
+ },
+ lineNumber: 2,
+ contactType: data.RegistrationContactTypeEmail,
+ verificationField: data.VerificationTypeDateOfBirth,
+ hasErrors: true,
+ expectedErrors: map[string]interface{}{
+ "line 2 - email": "email cannot be empty",
+ },
+ },
+ {
+ name: "error with all fields empty (phone, id, amount, verification)",
instruction: &data.DisbursementInstruction{},
lineNumber: 2,
- verificationField: data.VerificationFieldDateOfBirth,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypeDateOfBirth,
hasErrors: true,
expectedErrors: map[string]interface{}{
+ "line 2 - phone": "phone cannot be empty",
"line 2 - amount": "invalid amount. Amount must be a positive number",
- "line 2 - date of birth": "date of birth cannot be empty",
"line 2 - id": "id cannot be empty",
- "line 2 - phone": "phone cannot be empty",
+ "line 2 - date of birth": "date of birth cannot be empty",
},
},
{
@@ -53,7 +71,8 @@ func Test_DisbursementInstructionsValidator_ValidateAndGetInstruction(t *testing
VerificationValue: "1990-01-01",
},
lineNumber: 2,
- verificationField: data.VerificationFieldDateOfBirth,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypeDateOfBirth,
hasErrors: true,
expectedErrors: map[string]interface{}{
"line 2 - phone": "invalid phone format. Correct format: +380445555555",
@@ -68,12 +87,29 @@ func Test_DisbursementInstructionsValidator_ValidateAndGetInstruction(t *testing
VerificationValue: "1990-01-01",
},
lineNumber: 3,
- verificationField: data.VerificationFieldDateOfBirth,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypeDateOfBirth,
hasErrors: true,
expectedErrors: map[string]interface{}{
"line 3 - amount": "invalid amount. Amount must be a positive number",
},
},
+ {
+ name: "error if email is not valid",
+ instruction: &data.DisbursementInstruction{
+ Email: "invalidemail",
+ ID: "123456789",
+ Amount: "100.5",
+ VerificationValue: "1990-01-01",
+ },
+ lineNumber: 3,
+ contactType: data.RegistrationContactTypeEmail,
+ verificationField: data.VerificationTypeDateOfBirth,
+ hasErrors: true,
+ expectedErrors: map[string]interface{}{
+ "line 3 - email": "invalid email format",
+ },
+ },
{
name: "error if amount is not positive",
instruction: &data.DisbursementInstruction{
@@ -83,7 +119,8 @@ func Test_DisbursementInstructionsValidator_ValidateAndGetInstruction(t *testing
VerificationValue: "1990-01-01",
},
lineNumber: 3,
- verificationField: data.VerificationFieldDateOfBirth,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypeDateOfBirth,
hasErrors: true,
expectedErrors: map[string]interface{}{
"line 3 - amount": "invalid amount. Amount must be a positive number",
@@ -98,7 +135,8 @@ func Test_DisbursementInstructionsValidator_ValidateAndGetInstruction(t *testing
VerificationValue: "1990/01/01",
},
lineNumber: 3,
- verificationField: data.VerificationFieldDateOfBirth,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypeDateOfBirth,
hasErrors: true,
expectedErrors: map[string]interface{}{
"line 3 - date of birth": "invalid date of birth format. Correct format: 1990-01-30",
@@ -113,7 +151,8 @@ func Test_DisbursementInstructionsValidator_ValidateAndGetInstruction(t *testing
VerificationValue: "2090-01-01",
},
lineNumber: 3,
- verificationField: data.VerificationFieldDateOfBirth,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypeDateOfBirth,
hasErrors: true,
expectedErrors: map[string]interface{}{
"line 3 - date of birth": "date of birth cannot be in the future",
@@ -128,7 +167,8 @@ func Test_DisbursementInstructionsValidator_ValidateAndGetInstruction(t *testing
VerificationValue: "1990/01",
},
lineNumber: 3,
- verificationField: data.VerificationFieldYearMonth,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypeYearMonth,
hasErrors: true,
expectedErrors: map[string]interface{}{
"line 3 - year/month": "invalid year/month format. Correct format: 1990-12",
@@ -143,7 +183,8 @@ func Test_DisbursementInstructionsValidator_ValidateAndGetInstruction(t *testing
VerificationValue: "2090-01",
},
lineNumber: 3,
- verificationField: data.VerificationFieldYearMonth,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypeYearMonth,
hasErrors: true,
expectedErrors: map[string]interface{}{
"line 3 - year/month": "year/month cannot be in the future",
@@ -158,7 +199,8 @@ func Test_DisbursementInstructionsValidator_ValidateAndGetInstruction(t *testing
VerificationValue: "123",
},
lineNumber: 3,
- verificationField: data.VerificationFieldPin,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypePin,
hasErrors: true,
expectedErrors: map[string]interface{}{
"line 3 - pin": "invalid pin length. Cannot have less than 4 or more than 8 characters in pin",
@@ -173,7 +215,8 @@ func Test_DisbursementInstructionsValidator_ValidateAndGetInstruction(t *testing
VerificationValue: "123456789",
},
lineNumber: 3,
- verificationField: data.VerificationFieldPin,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypePin,
hasErrors: true,
expectedErrors: map[string]interface{}{
"line 3 - pin": "invalid pin length. Cannot have less than 4 or more than 8 characters in pin",
@@ -188,12 +231,44 @@ func Test_DisbursementInstructionsValidator_ValidateAndGetInstruction(t *testing
VerificationValue: "6UZMB56FWTKV4U0PJ21TBR6VOQVYSGIMZG2HW2S0L7EK5K83W78",
},
lineNumber: 3,
- verificationField: data.VerificationFieldNationalID,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypeNationalID,
hasErrors: true,
expectedErrors: map[string]interface{}{
"line 3 - national id": "invalid national id. Cannot have more than 50 characters in national id",
},
},
+ {
+ name: "error when WalletAddress is empty for WalletAddress contact type",
+ instruction: &data.DisbursementInstruction{
+ WalletAddress: "",
+ Phone: "+380445555555",
+ ID: "123456789",
+ Amount: "100.5",
+ },
+ lineNumber: 3,
+ contactType: data.RegistrationContactTypePhoneAndWalletAddress,
+ hasErrors: true,
+ expectedErrors: map[string]interface{}{
+ "line 3 - wallet address": "wallet address cannot be empty",
+ },
+ },
+ {
+ name: "error when WalletAddress is not valid for WalletAddress contact type",
+ instruction: &data.DisbursementInstruction{
+ WalletAddress: "invalidwalletaddress",
+ Phone: "+380445555555",
+ ID: "123456789",
+ Amount: "100.5",
+ },
+ lineNumber: 3,
+ contactType: data.RegistrationContactTypePhoneAndWalletAddress,
+ hasErrors: true,
+ expectedErrors: map[string]interface{}{
+ "line 3 - wallet address": "invalid wallet address. Must be a valid Stellar public key",
+ },
+ },
+
// VALID CASES
{
name: "🎉 successfully validates instructions (DATE_OF_BIRTH)",
@@ -204,7 +279,8 @@ func Test_DisbursementInstructionsValidator_ValidateAndGetInstruction(t *testing
VerificationValue: "1990-01-01",
},
lineNumber: 1,
- verificationField: data.VerificationFieldDateOfBirth,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypeDateOfBirth,
hasErrors: false,
},
{
@@ -216,7 +292,8 @@ func Test_DisbursementInstructionsValidator_ValidateAndGetInstruction(t *testing
VerificationValue: "1990-01",
},
lineNumber: 1,
- verificationField: data.VerificationFieldYearMonth,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypeYearMonth,
hasErrors: false,
},
{
@@ -228,7 +305,8 @@ func Test_DisbursementInstructionsValidator_ValidateAndGetInstruction(t *testing
VerificationValue: "ABCD123",
},
lineNumber: 3,
- verificationField: data.VerificationFieldNationalID,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypeNationalID,
hasErrors: false,
},
{
@@ -240,14 +318,53 @@ func Test_DisbursementInstructionsValidator_ValidateAndGetInstruction(t *testing
VerificationValue: "1234",
},
lineNumber: 3,
- verificationField: data.VerificationFieldPin,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypePin,
hasErrors: false,
},
+ {
+ name: "🎉 successfully validates instructions (Email)",
+ instruction: &data.DisbursementInstruction{
+ Email: "myemail@stellar.org",
+ ID: "123456789",
+ Amount: "100.5",
+ VerificationValue: "1234",
+ },
+ lineNumber: 3,
+ contactType: data.RegistrationContactTypeEmail,
+ verificationField: data.VerificationTypePin,
+ hasErrors: false,
+ },
+ {
+ name: "🎉 successfully validates instructions (Phone)",
+ instruction: &data.DisbursementInstruction{
+ Phone: "+380445555555",
+ ID: "123456789",
+ Amount: "100.5",
+ VerificationValue: "1234",
+ },
+ lineNumber: 3,
+ contactType: data.RegistrationContactTypePhone,
+ verificationField: data.VerificationTypePin,
+ hasErrors: false,
+ },
+ {
+ name: "🎉 successfully validates instructions (WalletAddress)",
+ instruction: &data.DisbursementInstruction{
+ WalletAddress: "GB3SAK22KSTIFQAV5GCDNPW7RTQCWGFDKALBY5KJ3JRF2DLSED3E7PVH",
+ Phone: "+380445555555",
+ ID: "123456789",
+ Amount: "100.5",
+ },
+ lineNumber: 3,
+ contactType: data.RegistrationContactTypePhoneAndWalletAddress,
+ hasErrors: false,
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- iv := NewDisbursementInstructionsValidator(tt.verificationField)
+ iv := NewDisbursementInstructionsValidator(tt.contactType, tt.verificationField)
iv.ValidateInstruction(tt.instruction, tt.lineNumber)
if tt.hasErrors {
@@ -280,7 +397,7 @@ func Test_DisbursementInstructionsValidator_SanitizeInstruction(t *testing.T) {
ID: "123456789",
Amount: "100.5",
VerificationValue: "1990-01-01",
- ExternalPaymentId: nil,
+ ExternalPaymentId: "",
},
},
{
@@ -290,21 +407,36 @@ func Test_DisbursementInstructionsValidator_SanitizeInstruction(t *testing.T) {
ID: " 123456789 ",
Amount: " 100.5 ",
VerificationValue: " 1990-01-01 ",
- ExternalPaymentId: &externalPaymentIDWithSpaces,
+ ExternalPaymentId: externalPaymentIDWithSpaces,
},
expectedInstruction: &data.DisbursementInstruction{
Phone: "+380445555555",
ID: "123456789",
Amount: "100.5",
VerificationValue: "1990-01-01",
- ExternalPaymentId: &externalPaymentID,
+ ExternalPaymentId: externalPaymentID,
+ },
+ },
+ {
+ name: "Sanitized instruction with email",
+ actual: &data.DisbursementInstruction{
+ Email: " MyEmail@stellar.org ",
+ ID: " 123456789 ",
+ Amount: " 100.5 ",
+ VerificationValue: " 1990-01-01 ",
+ },
+ expectedInstruction: &data.DisbursementInstruction{
+ Email: "myemail@stellar.org",
+ ID: "123456789",
+ Amount: "100.5",
+ VerificationValue: "1990-01-01",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- iv := NewDisbursementInstructionsValidator(data.VerificationFieldDateOfBirth)
+ iv := NewDisbursementInstructionsValidator(data.RegistrationContactTypePhone, data.VerificationTypeDateOfBirth)
sanitizedInstruction := iv.SanitizeInstruction(tt.actual)
assert.Equal(t, tt.expectedInstruction, sanitizedInstruction)
diff --git a/internal/serve/validators/disbursement_request_validator.go b/internal/serve/validators/disbursement_request_validator.go
deleted file mode 100644
index f93721308..000000000
--- a/internal/serve/validators/disbursement_request_validator.go
+++ /dev/null
@@ -1,29 +0,0 @@
-package validators
-
-import (
- "fmt"
- "slices"
-
- "github.com/stellar/stellar-disbursement-platform-backend/internal/data"
-)
-
-type DisbursementRequestValidator struct {
- verificationField data.VerificationField
- *Validator
-}
-
-func NewDisbursementRequestValidator(verificationField data.VerificationField) *DisbursementRequestValidator {
- return &DisbursementRequestValidator{
- verificationField: verificationField,
- Validator: NewValidator(),
- }
-}
-
-// ValidateAndGetVerificationType validates if the verification type field is a valid value.
-func (dv *DisbursementRequestValidator) ValidateAndGetVerificationType() data.VerificationField {
- if !slices.Contains(data.GetAllVerificationFields(), dv.verificationField) {
- dv.Check(false, "verification_field", fmt.Sprintf("invalid parameter. valid values are: %v", data.GetAllVerificationFields()))
- return ""
- }
- return dv.verificationField
-}
diff --git a/internal/serve/validators/disbursement_request_validator_test.go b/internal/serve/validators/disbursement_request_validator_test.go
deleted file mode 100644
index 8d65be8cf..000000000
--- a/internal/serve/validators/disbursement_request_validator_test.go
+++ /dev/null
@@ -1,34 +0,0 @@
-package validators
-
-import (
- "testing"
-
- "github.com/stretchr/testify/assert"
-
- "github.com/stellar/stellar-disbursement-platform-backend/internal/data"
-)
-
-func Test_DisbursementRequestValidator_ValidateAndGetVerificationType(t *testing.T) {
- t.Run("Valid verification type", func(t *testing.T) {
- validField := []data.VerificationField{
- data.VerificationFieldDateOfBirth,
- data.VerificationFieldYearMonth,
- data.VerificationFieldPin,
- data.VerificationFieldNationalID,
- }
- for _, field := range validField {
- validator := NewDisbursementRequestValidator(field)
- assert.Equal(t, field, validator.ValidateAndGetVerificationType())
- }
- })
-
- t.Run("Invalid verification type", func(t *testing.T) {
- field := data.VerificationField("field")
- validator := NewDisbursementRequestValidator(field)
-
- actual := validator.ValidateAndGetVerificationType()
- assert.Empty(t, actual)
- assert.Equal(t, 1, len(validator.Errors))
- assert.Equal(t, "invalid parameter. valid values are: [DATE_OF_BIRTH YEAR_MONTH PIN NATIONAL_ID_NUMBER]", validator.Errors["verification_field"])
- })
-}
diff --git a/internal/serve/validators/mock.go b/internal/serve/validators/mock.go
index a799f56bc..e0d4b93d6 100644
--- a/internal/serve/validators/mock.go
+++ b/internal/serve/validators/mock.go
@@ -14,3 +14,19 @@ func (v *ReCAPTCHAValidatorMock) IsTokenValid(ctx context.Context, token string)
args := v.Called(ctx, token)
return args.Bool(0), args.Error(1)
}
+
+type testInterface interface {
+ mock.TestingT
+ Cleanup(func())
+}
+
+// NewReCAPTCHAValidatorMock creates a new instance of ReCAPTCHAValidatorMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
+// The first argument is typically a *testing.T value.
+func NewReCAPTCHAValidatorMock(t testInterface) *ReCAPTCHAValidatorMock {
+ mock := &ReCAPTCHAValidatorMock{}
+ mock.Mock.Test(t)
+
+ t.Cleanup(func() { mock.AssertExpectations(t) })
+
+ return mock
+}
diff --git a/internal/serve/validators/receiver_registration_validator.go b/internal/serve/validators/receiver_registration_validator.go
index 1a92f0455..10ca64a9f 100644
--- a/internal/serve/validators/receiver_registration_validator.go
+++ b/internal/serve/validators/receiver_registration_validator.go
@@ -23,46 +23,57 @@ func NewReceiverRegistrationValidator() *ReceiverRegistrationValidator {
// ValidateReceiver validates if the infos present in the ReceiverRegistrationRequest are valids.
func (rv *ReceiverRegistrationValidator) ValidateReceiver(receiverInfo *data.ReceiverRegistrationRequest) {
- phone := strings.TrimSpace(receiverInfo.PhoneNumber)
+ phone := utils.TrimAndLower(receiverInfo.PhoneNumber)
+ email := utils.TrimAndLower(receiverInfo.Email)
otp := strings.TrimSpace(receiverInfo.OTP)
verification := strings.TrimSpace(receiverInfo.VerificationValue)
- verificationType := strings.TrimSpace(string(receiverInfo.VerificationType))
+ verificationField := strings.TrimSpace(string(receiverInfo.VerificationField))
- // validate phone field
- rv.CheckError(utils.ValidatePhoneNumber(phone), "phone_number", "invalid phone format. Correct format: +380445555555")
- rv.Check(phone != "", "phone_number", "phone cannot be empty")
+ switch {
+ case phone == "" && email == "":
+ rv.Check(false, "phone_number", "phone_number or email is required")
+ rv.Check(false, "email", "phone_number or email is required")
+ case phone != "" && email != "":
+ rv.Check(false, "phone_number", "phone_number and email cannot be both provided")
+ rv.Check(false, "email", "phone_number and email cannot be both provided")
+ case phone != "":
+ rv.CheckError(utils.ValidatePhoneNumber(phone), "phone_number", "")
+ case email != "":
+ rv.CheckError(utils.ValidateEmail(email), "email", "")
+ }
// validate otp field
rv.CheckError(utils.ValidateOTP(otp), "otp", "invalid otp format. Needs to be a 6 digit value")
// validate verification type field
- rv.Check(verificationType != "", "verification_type", "verification type cannot be empty")
- vt := rv.validateAndGetVerificationType(verificationType)
+ rv.Check(verificationField != "", "verification_field", "verification type cannot be empty")
+ vf := rv.validateAndGetVerificationType(verificationField)
// validate verification fields
- switch vt {
- case data.VerificationFieldDateOfBirth:
+ switch vf {
+ case data.VerificationTypeDateOfBirth:
rv.CheckError(utils.ValidateDateOfBirthVerification(verification), "verification", "")
- case data.VerificationFieldYearMonth:
+ case data.VerificationTypeYearMonth:
rv.CheckError(utils.ValidateYearMonthVerification(verification), "verification", "")
- case data.VerificationFieldPin:
+ case data.VerificationTypePin:
rv.CheckError(utils.ValidatePinVerification(verification), "verification", "")
- case data.VerificationFieldNationalID:
+ case data.VerificationTypeNationalID:
rv.CheckError(utils.ValidateNationalIDVerification(verification), "verification", "")
}
receiverInfo.PhoneNumber = phone
+ receiverInfo.Email = email
receiverInfo.OTP = otp
receiverInfo.VerificationValue = verification
- receiverInfo.VerificationType = vt
+ receiverInfo.VerificationField = vf
}
// validateAndGetVerificationType validates if the verification type field is a valid value.
-func (rv *ReceiverRegistrationValidator) validateAndGetVerificationType(verificationType string) data.VerificationField {
- vt := data.VerificationField(strings.ToUpper(verificationType))
+func (rv *ReceiverRegistrationValidator) validateAndGetVerificationType(verificationType string) data.VerificationType {
+ vt := data.VerificationType(strings.ToUpper(verificationType))
- if !slices.Contains(data.GetAllVerificationFields(), vt) {
- rv.Check(false, "verification_type", fmt.Sprintf("invalid parameter. valid values are: %v", data.GetAllVerificationFields()))
+ if !slices.Contains(data.GetAllVerificationTypes(), vt) {
+ rv.Check(false, "verification_field", fmt.Sprintf("invalid parameter. valid values are: %v", data.GetAllVerificationTypes()))
return ""
}
return vt
diff --git a/internal/serve/validators/receiver_registration_validator_test.go b/internal/serve/validators/receiver_registration_validator_test.go
index b338f3bd7..84c7d6817 100644
--- a/internal/serve/validators/receiver_registration_validator_test.go
+++ b/internal/serve/validators/receiver_registration_validator_test.go
@@ -10,12 +10,10 @@ import (
func Test_ReceiverRegistrationValidator_ValidateReceiver(t *testing.T) {
type testCase struct {
- name string
- receiverInfo data.ReceiverRegistrationRequest
- expectedErrorLen int
- expectedErrorMsg string
- expectedErrorKey string
- expectedReceiver data.ReceiverRegistrationRequest
+ name string
+ receiverInfo data.ReceiverRegistrationRequest
+ expectedReceiver data.ReceiverRegistrationRequest
+ expectedValidationErrors map[string]interface{}
}
testCases := []testCase{
@@ -25,23 +23,50 @@ func Test_ReceiverRegistrationValidator_ValidateReceiver(t *testing.T) {
PhoneNumber: "invalid",
OTP: "123456",
VerificationValue: "1990-01-01",
- VerificationType: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
+ },
+ expectedValidationErrors: map[string]interface{}{
+ "phone_number": "the provided phone number is not a valid E.164 number",
+ },
+ },
+ {
+ name: "error if email is invalid",
+ receiverInfo: data.ReceiverRegistrationRequest{
+ Email: "invalid",
+ OTP: "123456",
+ VerificationValue: "1990-01-01",
+ VerificationField: data.VerificationTypeDateOfBirth,
+ },
+ expectedValidationErrors: map[string]interface{}{
+ "email": "the provided email is not valid",
},
- expectedErrorLen: 1,
- expectedErrorMsg: "invalid phone format. Correct format: +380445555555",
- expectedErrorKey: "phone_number",
},
{
- name: "error if phone number is empty",
+ name: "error if phone number and email are empty",
receiverInfo: data.ReceiverRegistrationRequest{
PhoneNumber: "",
OTP: "123456",
VerificationValue: "1990-01-01",
- VerificationType: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
+ },
+ expectedValidationErrors: map[string]interface{}{
+ "phone_number": "phone_number or email is required",
+ "email": "phone_number or email is required",
+ },
+ },
+ {
+ name: "error if phone number and email are provided",
+ receiverInfo: data.ReceiverRegistrationRequest{
+ Email: "test@stellar.com",
+ PhoneNumber: "+380445555555",
+ OTP: "123456",
+ VerificationValue: "1990-01-01",
+ VerificationField: data.VerificationTypeDateOfBirth,
+ },
+ expectedValidationErrors: map[string]interface{}{
+ "phone_number": "phone_number and email cannot be both provided",
+ "email": "phone_number and email cannot be both provided",
},
- expectedErrorLen: 1,
- expectedErrorMsg: "phone cannot be empty",
- expectedErrorKey: "phone_number",
},
{
name: "error if OTP is invalid",
@@ -49,11 +74,11 @@ func Test_ReceiverRegistrationValidator_ValidateReceiver(t *testing.T) {
PhoneNumber: "+380445555555",
OTP: "12mock",
VerificationValue: "1990-01-01",
- VerificationType: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
+ },
+ expectedValidationErrors: map[string]interface{}{
+ "otp": "invalid otp format. Needs to be a 6 digit value",
},
- expectedErrorLen: 1,
- expectedErrorMsg: "invalid otp format. Needs to be a 6 digit value",
- expectedErrorKey: "otp",
},
{
name: "error if verification type is invalid",
@@ -61,11 +86,11 @@ func Test_ReceiverRegistrationValidator_ValidateReceiver(t *testing.T) {
PhoneNumber: "+380445555555",
OTP: "123456",
VerificationValue: "1990-01-01",
- VerificationType: "mock_type",
+ VerificationField: "mock_type",
+ },
+ expectedValidationErrors: map[string]interface{}{
+ "verification_field": "invalid parameter. valid values are: [DATE_OF_BIRTH YEAR_MONTH PIN NATIONAL_ID_NUMBER]",
},
- expectedErrorLen: 1,
- expectedErrorMsg: "invalid parameter. valid values are: [DATE_OF_BIRTH YEAR_MONTH PIN NATIONAL_ID_NUMBER]",
- expectedErrorKey: "verification_type",
},
{
name: "error if verification[DATE_OF_BIRTH] is invalid",
@@ -73,11 +98,11 @@ func Test_ReceiverRegistrationValidator_ValidateReceiver(t *testing.T) {
PhoneNumber: "+380445555555",
OTP: "123456",
VerificationValue: "90/01/01",
- VerificationType: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
+ },
+ expectedValidationErrors: map[string]interface{}{
+ "verification": "invalid date of birth format. Correct format: 1990-01-30",
},
- expectedErrorLen: 1,
- expectedErrorMsg: "invalid date of birth format. Correct format: 1990-01-30",
- expectedErrorKey: "verification",
},
{
name: "error if verification[YEAR_MONTH] is invalid",
@@ -85,11 +110,11 @@ func Test_ReceiverRegistrationValidator_ValidateReceiver(t *testing.T) {
PhoneNumber: "+380445555555",
OTP: "123456",
VerificationValue: "90/12",
- VerificationType: data.VerificationFieldYearMonth,
+ VerificationField: data.VerificationTypeYearMonth,
+ },
+ expectedValidationErrors: map[string]interface{}{
+ "verification": "invalid year/month format. Correct format: 1990-12",
},
- expectedErrorLen: 1,
- expectedErrorMsg: "invalid year/month format. Correct format: 1990-12",
- expectedErrorKey: "verification",
},
{
name: "error if verification[PIN] is invalid",
@@ -97,11 +122,11 @@ func Test_ReceiverRegistrationValidator_ValidateReceiver(t *testing.T) {
PhoneNumber: "+380445555555",
OTP: "123456",
VerificationValue: "ABCDE1234",
- VerificationType: data.VerificationFieldPin,
+ VerificationField: data.VerificationTypePin,
+ },
+ expectedValidationErrors: map[string]interface{}{
+ "verification": "invalid pin length. Cannot have less than 4 or more than 8 characters in pin",
},
- expectedErrorLen: 1,
- expectedErrorMsg: "invalid pin length. Cannot have less than 4 or more than 8 characters in pin",
- expectedErrorKey: "verification",
},
{
name: "error if verification[NATIONAL_ID_NUMBER] is invalid",
@@ -109,11 +134,11 @@ func Test_ReceiverRegistrationValidator_ValidateReceiver(t *testing.T) {
PhoneNumber: "+380445555555",
OTP: "123456",
VerificationValue: "6UZMB56FWTKV4U0PJ21TBR6VOQVYSGIMZG2HW2S0L7EK5K83W78XXXXX",
- VerificationType: data.VerificationFieldNationalID,
+ VerificationField: data.VerificationTypeNationalID,
+ },
+ expectedValidationErrors: map[string]interface{}{
+ "verification": "invalid national id. Cannot have more than 50 characters in national id",
},
- expectedErrorLen: 1,
- expectedErrorMsg: "invalid national id. Cannot have more than 50 characters in national id",
- expectedErrorKey: "verification",
},
{
name: "🎉 successfully validates receiver values [DATE_OF_BIRTH]",
@@ -121,14 +146,14 @@ func Test_ReceiverRegistrationValidator_ValidateReceiver(t *testing.T) {
PhoneNumber: "+380445555555 ",
OTP: " 123456 ",
VerificationValue: "1990-01-01 ",
- VerificationType: "date_of_birth",
+ VerificationField: "date_of_birth",
},
- expectedErrorLen: 0,
+ expectedValidationErrors: map[string]interface{}{},
expectedReceiver: data.ReceiverRegistrationRequest{
PhoneNumber: "+380445555555",
OTP: "123456",
VerificationValue: "1990-01-01",
- VerificationType: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
},
},
{
@@ -137,14 +162,14 @@ func Test_ReceiverRegistrationValidator_ValidateReceiver(t *testing.T) {
PhoneNumber: "+380445555555 ",
OTP: " 123456 ",
VerificationValue: "1990-12 ",
- VerificationType: "year_month",
+ VerificationField: "year_month",
},
- expectedErrorLen: 0,
+ expectedValidationErrors: map[string]interface{}{},
expectedReceiver: data.ReceiverRegistrationRequest{
PhoneNumber: "+380445555555",
OTP: "123456",
VerificationValue: "1990-12",
- VerificationType: data.VerificationFieldYearMonth,
+ VerificationField: data.VerificationTypeYearMonth,
},
},
{
@@ -153,14 +178,14 @@ func Test_ReceiverRegistrationValidator_ValidateReceiver(t *testing.T) {
PhoneNumber: "+380445555555 ",
OTP: " 123456 ",
VerificationValue: "1234 ",
- VerificationType: "pin",
+ VerificationField: "pin",
},
- expectedErrorLen: 0,
+ expectedValidationErrors: map[string]interface{}{},
expectedReceiver: data.ReceiverRegistrationRequest{
PhoneNumber: "+380445555555",
OTP: "123456",
VerificationValue: "1234",
- VerificationType: data.VerificationFieldPin,
+ VerificationField: data.VerificationTypePin,
},
},
{
@@ -169,14 +194,14 @@ func Test_ReceiverRegistrationValidator_ValidateReceiver(t *testing.T) {
PhoneNumber: "+380445555555 ",
OTP: " 123456 ",
VerificationValue: " NATIONALIDNUMBER123",
- VerificationType: "national_id_number",
+ VerificationField: "national_id_number",
},
- expectedErrorLen: 0,
+ expectedValidationErrors: map[string]interface{}{},
expectedReceiver: data.ReceiverRegistrationRequest{
PhoneNumber: "+380445555555",
OTP: "123456",
VerificationValue: "NATIONALIDNUMBER123",
- VerificationType: data.VerificationFieldNationalID,
+ VerificationField: data.VerificationTypeNationalID,
},
},
}
@@ -186,15 +211,14 @@ func Test_ReceiverRegistrationValidator_ValidateReceiver(t *testing.T) {
validator := NewReceiverRegistrationValidator()
validator.ValidateReceiver(&tc.receiverInfo)
- assert.Equal(t, tc.expectedErrorLen, len(validator.Errors))
-
- if tc.expectedErrorLen > 0 {
- assert.Equal(t, tc.expectedErrorMsg, validator.Errors[tc.expectedErrorKey])
+ if len(tc.expectedValidationErrors) > 0 {
+ assert.Equal(t, tc.expectedValidationErrors, validator.Errors)
} else {
+ assert.Equal(t, tc.expectedReceiver.Email, tc.receiverInfo.Email)
assert.Equal(t, tc.expectedReceiver.PhoneNumber, tc.receiverInfo.PhoneNumber)
assert.Equal(t, tc.expectedReceiver.OTP, tc.receiverInfo.OTP)
assert.Equal(t, tc.expectedReceiver.VerificationValue, tc.receiverInfo.VerificationValue)
- assert.Equal(t, tc.expectedReceiver.VerificationType, tc.receiverInfo.VerificationType)
+ assert.Equal(t, tc.expectedReceiver.VerificationField, tc.receiverInfo.VerificationField)
}
})
}
@@ -203,11 +227,11 @@ func Test_ReceiverRegistrationValidator_ValidateReceiver(t *testing.T) {
func Test_ReceiverRegistrationValidator_ValidateAndGetVerificationType(t *testing.T) {
t.Run("Valid verification type", func(t *testing.T) {
validator := NewReceiverRegistrationValidator()
- validField := []data.VerificationField{
- data.VerificationFieldDateOfBirth,
- data.VerificationFieldYearMonth,
- data.VerificationFieldPin,
- data.VerificationFieldNationalID,
+ validField := []data.VerificationType{
+ data.VerificationTypeDateOfBirth,
+ data.VerificationTypeYearMonth,
+ data.VerificationTypePin,
+ data.VerificationTypeNationalID,
}
for _, field := range validField {
assert.Equal(t, field, validator.validateAndGetVerificationType(string(field)))
@@ -221,6 +245,6 @@ func Test_ReceiverRegistrationValidator_ValidateAndGetVerificationType(t *testin
actual := validator.validateAndGetVerificationType(invalidStatus)
assert.Empty(t, actual)
assert.Equal(t, 1, len(validator.Errors))
- assert.Equal(t, "invalid parameter. valid values are: [DATE_OF_BIRTH YEAR_MONTH PIN NATIONAL_ID_NUMBER]", validator.Errors["verification_type"])
+ assert.Equal(t, "invalid parameter. valid values are: [DATE_OF_BIRTH YEAR_MONTH PIN NATIONAL_ID_NUMBER]", validator.Errors["verification_field"])
})
}
diff --git a/internal/serve/validators/receiver_update_validator.go b/internal/serve/validators/receiver_update_validator.go
index b77bf0e42..b68a73397 100644
--- a/internal/serve/validators/receiver_update_validator.go
+++ b/internal/serve/validators/receiver_update_validator.go
@@ -7,11 +7,14 @@ import (
)
type UpdateReceiverRequest struct {
+ // receiver_verifications fields:
DateOfBirth string `json:"date_of_birth"`
YearMonth string `json:"year_month"`
Pin string `json:"pin"`
NationalID string `json:"national_id"`
+ // receivers fields:
Email string `json:"email"`
+ PhoneNumber string `json:"phone_number"`
ExternalID string `json:"external_id"`
}
type UpdateReceiverValidator struct {
@@ -60,8 +63,12 @@ func (ur *UpdateReceiverValidator) ValidateReceiver(updateReceiverRequest *Updat
ur.Check(utils.ValidateEmail(email) == nil, "email", "invalid email format")
}
+ if updateReceiverRequest.PhoneNumber != "" {
+ ur.Check(utils.ValidatePhoneNumber(updateReceiverRequest.PhoneNumber) == nil, "phone_number", "invalid phone number format")
+ }
+
if updateReceiverRequest.ExternalID != "" {
- ur.Check(externalID != "", "external_id", "invalid external_id format")
+ ur.Check(externalID != "", "external_id", "external_id cannot be set to empty")
}
updateReceiverRequest.DateOfBirth = dateOfBirth
diff --git a/internal/serve/validators/receiver_update_validator_test.go b/internal/serve/validators/receiver_update_validator_test.go
index f1960b338..3b7c3a51c 100644
--- a/internal/serve/validators/receiver_update_validator_test.go
+++ b/internal/serve/validators/receiver_update_validator_test.go
@@ -6,100 +6,107 @@ import (
"github.com/stretchr/testify/assert"
)
-func Test_UpdateReceiverValidator_ValidateReceiver(t *testing.T) {
- t.Run("Empty request", func(t *testing.T) {
- validator := NewUpdateReceiverValidator()
-
- receiverInfo := UpdateReceiverRequest{}
- validator.ValidateReceiver(&receiverInfo)
-
- assert.Equal(t, 1, len(validator.Errors))
- assert.Equal(t, "request body is empty", validator.Errors["body"])
- })
-
- t.Run("Invalid date of birth", func(t *testing.T) {
- validator := NewUpdateReceiverValidator()
-
- receiverInfo := UpdateReceiverRequest{
- DateOfBirth: "invalid",
- }
- validator.ValidateReceiver(&receiverInfo)
-
- assert.Equal(t, 1, len(validator.Errors))
- assert.Equal(t, "invalid date of birth format. Correct format: 1990-01-30", validator.Errors["date_of_birth"])
- })
-
- t.Run("Invalid pin", func(t *testing.T) {
- validator := NewUpdateReceiverValidator()
-
- receiverInfo := UpdateReceiverRequest{
- Pin: " ",
- }
- validator.ValidateReceiver(&receiverInfo)
-
- assert.Equal(t, 1, len(validator.Errors))
- assert.Equal(t, "invalid pin length. Cannot have less than 4 or more than 8 characters in pin", validator.Errors["pin"])
- })
-
- t.Run("Invalid national ID", func(t *testing.T) {
- validator := NewUpdateReceiverValidator()
-
- receiverInfo := UpdateReceiverRequest{
- NationalID: " ",
- }
- validator.ValidateReceiver(&receiverInfo)
-
- assert.Equal(t, 1, len(validator.Errors))
- assert.Equal(t, "national id cannot be empty", validator.Errors["national_id"])
- })
-
- t.Run("invalid email", func(t *testing.T) {
- validator := NewUpdateReceiverValidator()
-
- receiverInfo := UpdateReceiverRequest{
- Email: "invalid",
- }
- validator.ValidateReceiver(&receiverInfo)
-
- assert.Equal(t, 1, len(validator.Errors))
- assert.Equal(t, "invalid email format", validator.Errors["email"])
-
- receiverInfo = UpdateReceiverRequest{
- Email: " ",
- }
- validator.ValidateReceiver(&receiverInfo)
-
- assert.Equal(t, 1, len(validator.Errors))
- assert.Equal(t, "invalid email format", validator.Errors["email"])
- })
-
- t.Run("invalid external ID", func(t *testing.T) {
- validator := NewUpdateReceiverValidator()
-
- receiverInfo := UpdateReceiverRequest{
- ExternalID: " ",
- }
- validator.ValidateReceiver(&receiverInfo)
-
- assert.Equal(t, 1, len(validator.Errors))
- assert.Equal(t, "invalid external_id format", validator.Errors["external_id"])
- })
-
- t.Run("Valid receiver values", func(t *testing.T) {
- validator := NewUpdateReceiverValidator()
-
- receiverInfo := UpdateReceiverRequest{
- DateOfBirth: "1999-01-01",
- Pin: "1234 ",
- NationalID: " 12345CODE",
- Email: "receiver@email.com",
- ExternalID: "externalID",
- }
- validator.ValidateReceiver(&receiverInfo)
-
- assert.Equal(t, 0, len(validator.Errors))
- assert.Equal(t, "1999-01-01", receiverInfo.DateOfBirth)
- assert.Equal(t, "1234", receiverInfo.Pin)
- assert.Equal(t, "12345CODE", receiverInfo.NationalID)
- })
+func Test_UpdateReceiverValidator_ValidateReceiver2(t *testing.T) {
+ testCases := []struct {
+ name string
+ request UpdateReceiverRequest
+ expectedErrors map[string]interface{}
+ }{
+ {
+ name: "Empty request",
+ request: UpdateReceiverRequest{},
+ expectedErrors: map[string]interface{}{
+ "body": "request body is empty",
+ },
+ },
+ {
+ name: "[DATE_OF_BIRTH] ValidationField is invalid",
+ request: UpdateReceiverRequest{
+ DateOfBirth: "invalid",
+ },
+ expectedErrors: map[string]interface{}{
+ "date_of_birth": "invalid date of birth format. Correct format: 1990-01-30",
+ },
+ },
+ {
+ name: "[YEAR_MONTH] ValidationField is invalid",
+ request: UpdateReceiverRequest{
+ YearMonth: "invalid",
+ },
+ expectedErrors: map[string]interface{}{
+ "year_month": "invalid year/month format. Correct format: 1990-12",
+ },
+ },
+ {
+ name: "[PIN] ValidationField is invalid",
+ request: UpdateReceiverRequest{
+ Pin: " ",
+ },
+ expectedErrors: map[string]interface{}{
+ "pin": "invalid pin length. Cannot have less than 4 or more than 8 characters in pin",
+ },
+ },
+ {
+ name: "[NATIONAL_ID_NUMBER] ValidationField is invalid",
+ request: UpdateReceiverRequest{
+ NationalID: " ",
+ },
+ expectedErrors: map[string]interface{}{
+ "national_id": "national id cannot be empty",
+ },
+ },
+ {
+ name: "e-mail is invalid",
+ request: UpdateReceiverRequest{
+ Email: "invalid",
+ },
+ expectedErrors: map[string]interface{}{
+ "email": "invalid email format",
+ },
+ },
+ {
+ name: "phone number is invalid",
+ request: UpdateReceiverRequest{
+ PhoneNumber: "invalid",
+ },
+ expectedErrors: map[string]interface{}{
+ "phone_number": "invalid phone number format",
+ },
+ },
+ {
+ name: "external ID is invalid",
+ request: UpdateReceiverRequest{
+ ExternalID: " ",
+ },
+ expectedErrors: map[string]interface{}{
+ "external_id": "external_id cannot be set to empty",
+ },
+ },
+ {
+ name: "🎉 Valid receiver values",
+ request: UpdateReceiverRequest{
+ DateOfBirth: "1999-01-01",
+ YearMonth: "1999-01",
+ Pin: "1234 ",
+ NationalID: " 12345CODE",
+ Email: "receiver@email.com",
+ PhoneNumber: "+14155556666",
+ ExternalID: "externalID",
+ },
+ expectedErrors: map[string]interface{}{},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ validator := NewUpdateReceiverValidator()
+ validator.ValidateReceiver(&tc.request)
+
+ assert.Equal(t, len(tc.expectedErrors), len(validator.Errors))
+ assert.Equal(t, tc.expectedErrors, validator.Errors)
+ for key, value := range tc.expectedErrors {
+ assert.Equal(t, value, validator.Errors[key])
+ }
+ })
+ }
}
diff --git a/internal/serve/validators/wallet_validator.go b/internal/serve/validators/wallet_validator.go
index 570050660..dcd34239d 100644
--- a/internal/serve/validators/wallet_validator.go
+++ b/internal/serve/validators/wallet_validator.go
@@ -6,6 +6,8 @@ import (
"strings"
"github.com/stellar/go/support/log"
+
+ "github.com/stellar/stellar-disbursement-platform-backend/internal/utils"
)
type WalletRequest struct {
@@ -28,13 +30,14 @@ func NewWalletValidator() *WalletValidator {
return &WalletValidator{Validator: NewValidator()}
}
-func (wv *WalletValidator) ValidateCreateWalletRequest(ctx context.Context, reqBody *WalletRequest) *WalletRequest {
+func (wv *WalletValidator) ValidateCreateWalletRequest(ctx context.Context, reqBody *WalletRequest, enforceHTTPS bool) *WalletRequest {
+ // empty body validation
wv.Check(reqBody != nil, "body", "request body is empty")
-
if wv.HasErrors() {
return nil
}
+ // empty fields validation
name := strings.TrimSpace(reqBody.Name)
homepage := strings.TrimSpace(reqBody.Homepage)
deepLinkSchema := strings.TrimSpace(reqBody.DeepLinkSchema)
@@ -45,15 +48,21 @@ func (wv *WalletValidator) ValidateCreateWalletRequest(ctx context.Context, reqB
wv.Check(deepLinkSchema != "", "deep_link_schema", "deep_link_schema is required")
wv.Check(sep10ClientDomain != "", "sep_10_client_domain", "sep_10_client_domain is required")
wv.Check(len(reqBody.AssetsIDs) != 0, "assets_ids", "provide at least one asset ID")
-
if wv.HasErrors() {
return nil
}
+ // fields format validation
homepageURL, err := url.ParseRequestURI(homepage)
if err != nil {
log.Ctx(ctx).Errorf("parsing homepage URL: %v", err)
wv.Check(false, "homepage", "invalid homepage URL provided")
+ } else {
+ schemes := []string{"https"}
+ if !enforceHTTPS {
+ schemes = append(schemes, "http")
+ }
+ wv.CheckError(utils.ValidateURLScheme(homepage, schemes...), "homepage", "")
}
deepLinkSchemaURL, err := url.ParseRequestURI(deepLinkSchema)
@@ -68,14 +77,18 @@ func (wv *WalletValidator) ValidateCreateWalletRequest(ctx context.Context, reqB
wv.Check(false, "sep_10_client_domain", "invalid SEP-10 client domain URL provided")
}
- if wv.HasErrors() {
- return nil
- }
-
sep10Host := sep10URL.Host
if sep10Host == "" {
sep10Host = sep10URL.String()
}
+ if err := utils.ValidateDNS(sep10Host); err != nil {
+ log.Ctx(ctx).Errorf("validating SEP-10 client domain: %v", err)
+ wv.Check(false, "sep_10_client_domain", "invalid SEP-10 client domain provided")
+ }
+
+ if wv.HasErrors() {
+ return nil
+ }
modifiedReq := &WalletRequest{
Name: name,
diff --git a/internal/serve/validators/wallet_validator_test.go b/internal/serve/validators/wallet_validator_test.go
index f6ba31644..fcaf71201 100644
--- a/internal/serve/validators/wallet_validator_test.go
+++ b/internal/serve/validators/wallet_validator_test.go
@@ -4,7 +4,6 @@ import (
"context"
"testing"
- "github.com/stellar/go/support/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -12,128 +11,112 @@ import (
func TestWalletValidator_ValidateCreateWalletRequest(t *testing.T) {
ctx := context.Background()
- t.Run("returns error when request body is empty", func(t *testing.T) {
- wv := NewWalletValidator()
- wv.ValidateCreateWalletRequest(ctx, nil)
- assert.True(t, wv.HasErrors())
- assert.Equal(t, map[string]interface{}{"body": "request body is empty"}, wv.Errors)
- })
-
- t.Run("returns error when request body has empty fields", func(t *testing.T) {
- wv := NewWalletValidator()
- reqBody := &WalletRequest{}
-
- wv.ValidateCreateWalletRequest(ctx, reqBody)
- assert.True(t, wv.HasErrors())
- assert.Equal(t, map[string]interface{}{
- "deep_link_schema": "deep_link_schema is required",
- "homepage": "homepage is required",
- "name": "name is required",
- "sep_10_client_domain": "sep_10_client_domain is required",
- "assets_ids": "provide at least one asset ID",
- }, wv.Errors)
-
- reqBody.Name = "Wallet Provider"
- wv.Errors = map[string]interface{}{}
- wv.ValidateCreateWalletRequest(ctx, reqBody)
- assert.True(t, wv.HasErrors())
- assert.Equal(t, map[string]interface{}{
- "deep_link_schema": "deep_link_schema is required",
- "homepage": "homepage is required",
- "sep_10_client_domain": "sep_10_client_domain is required",
- "assets_ids": "provide at least one asset ID",
- }, wv.Errors)
- })
-
- t.Run("returns error when homepage/deep link schema has a invalid URL", func(t *testing.T) {
- getEntries := log.DefaultLogger.StartTest(log.ErrorLevel)
-
- wv := NewWalletValidator()
- reqBody := &WalletRequest{
- Name: "Wallet Provider",
- Homepage: "no-schema-homepage.com",
- DeepLinkSchema: "no-schema-deep-link",
- SEP10ClientDomain: "sep-10-client-domain.com",
- AssetsIDs: []string{"asset-id"},
- }
-
- wv.ValidateCreateWalletRequest(ctx, reqBody)
-
- assert.True(t, wv.HasErrors())
-
- assert.Contains(t, wv.Errors, "homepage")
- assert.Equal(t, "invalid homepage URL provided", wv.Errors["homepage"])
-
- assert.Contains(t, wv.Errors, "deep_link_schema")
- assert.Equal(t, "invalid deep link schema provided", wv.Errors["deep_link_schema"])
-
- entries := getEntries()
- require.Len(t, entries, 2)
- assert.Equal(t, `parsing homepage URL: parse "no-schema-homepage.com": invalid URI for request`, entries[0].Message)
- assert.Equal(t, `parsing deep link schema: parse "no-schema-deep-link": invalid URI for request`, entries[1].Message)
- })
-
- t.Run("validates the homepage successfully", func(t *testing.T) {
- wv := NewWalletValidator()
- reqBody := &WalletRequest{
- Name: "Wallet Provider",
- Homepage: "https://homepage.com",
- DeepLinkSchema: "wallet://deeplinkschema/sdp",
- SEP10ClientDomain: "sep-10-client-domain.com",
- AssetsIDs: []string{"asset-id"},
- }
-
- wv.ValidateCreateWalletRequest(ctx, reqBody)
- assert.False(t, wv.HasErrors())
-
- reqBody.Homepage = "http://homepage.com/sdp?redirect=true"
- wv.ValidateCreateWalletRequest(ctx, reqBody)
- assert.False(t, wv.HasErrors())
- assert.Equal(t, map[string]interface{}{}, wv.Errors)
- })
-
- t.Run("validates the deep link schema successfully", func(t *testing.T) {
- wv := NewWalletValidator()
- reqBody := &WalletRequest{
- Name: "Wallet Provider",
- Homepage: "https://homepage.com",
- DeepLinkSchema: "wallet://deeplinkschema/sdp",
- SEP10ClientDomain: "sep-10-client-domain.com",
- AssetsIDs: []string{"asset-id"},
- }
-
- wv.ValidateCreateWalletRequest(ctx, reqBody)
- assert.False(t, wv.HasErrors())
-
- reqBody.DeepLinkSchema = "https://deeplinkschema.com/sdp?redirect=true"
- wv.ValidateCreateWalletRequest(ctx, reqBody)
- assert.False(t, wv.HasErrors())
- })
-
- t.Run("validates the SEP-10 Client Domain successfully", func(t *testing.T) {
- wv := NewWalletValidator()
- reqBody := &WalletRequest{
- Name: "Wallet Provider",
- Homepage: "https://homepage.com",
- DeepLinkSchema: "wallet://deeplinkschema/sdp",
- SEP10ClientDomain: "https://sep-10-client-domain.com",
- AssetsIDs: []string{"asset-id"},
- }
-
- reqBody = wv.ValidateCreateWalletRequest(ctx, reqBody)
- assert.False(t, wv.HasErrors())
- assert.Equal(t, "sep-10-client-domain.com", reqBody.SEP10ClientDomain)
-
- reqBody.SEP10ClientDomain = "https://sep-10-client-domain.com/sdp?redirect=true"
- reqBody = wv.ValidateCreateWalletRequest(ctx, reqBody)
- assert.False(t, wv.HasErrors())
- assert.Equal(t, "sep-10-client-domain.com", reqBody.SEP10ClientDomain)
-
- reqBody.SEP10ClientDomain = "http://localhost:8000"
- reqBody = wv.ValidateCreateWalletRequest(ctx, reqBody)
- assert.False(t, wv.HasErrors())
- assert.Equal(t, "localhost:8000", reqBody.SEP10ClientDomain)
- })
+ testCases := []struct {
+ name string
+ reqBody *WalletRequest
+ expectedErrs map[string]interface{}
+ updateRequestFn func(wr *WalletRequest)
+ enforceHTTPS bool
+ }{
+ {
+ name: "🔴 error when request body is empty",
+ reqBody: nil,
+ expectedErrs: map[string]interface{}{"body": "request body is empty"},
+ },
+ {
+ name: "🔴 error when request body has empty fields",
+ reqBody: &WalletRequest{},
+ expectedErrs: map[string]interface{}{
+ "deep_link_schema": "deep_link_schema is required",
+ "homepage": "homepage is required",
+ "name": "name is required",
+ "sep_10_client_domain": "sep_10_client_domain is required",
+ "assets_ids": "provide at least one asset ID",
+ },
+ },
+ {
+ name: "🔴 error when homepage,deep-link,client-domain are invalid",
+ reqBody: &WalletRequest{
+ Name: "Wallet Provider",
+ Homepage: "no-schema-homepage.com",
+ DeepLinkSchema: "no-schema-deep-link",
+ SEP10ClientDomain: "-invaliddomain",
+ AssetsIDs: []string{"asset-id"},
+ },
+ expectedErrs: map[string]interface{}{
+ "homepage": "invalid homepage URL provided",
+ "deep_link_schema": "invalid deep link schema provided",
+ "sep_10_client_domain": "invalid SEP-10 client domain provided",
+ },
+ },
+ {
+ name: "🟢 successfully validates the homepage,deep-link,client-domain",
+ reqBody: &WalletRequest{
+ Name: "Wallet Provider",
+ Homepage: "https://homepage.com",
+ DeepLinkSchema: "wallet://deeplinkschema/sdp",
+ SEP10ClientDomain: "sep-10-client-domain.com",
+ AssetsIDs: []string{"asset-id"},
+ },
+ expectedErrs: map[string]interface{}{},
+ },
+ {
+ name: "🟢 successfully validates the homepage,deep-link,client-domain with query params",
+ reqBody: &WalletRequest{
+ Name: "Wallet Provider",
+ Homepage: "http://homepage.com/sdp?redirect=true",
+ DeepLinkSchema: "https://deeplinkschema.com/sdp?redirect=true",
+ SEP10ClientDomain: "sep-10-client-domain.com",
+ AssetsIDs: []string{"asset-id"},
+ },
+ expectedErrs: map[string]interface{}{},
+ },
+ {
+ name: "🔴 fails if enforceHttps=true && homepage=http://...",
+ reqBody: &WalletRequest{
+ Name: "Wallet Provider",
+ Homepage: "http://homepage.com/sdp?redirect=true",
+ DeepLinkSchema: "https://deeplinkschema.com/sdp?redirect=true",
+ SEP10ClientDomain: "sep-10-client-domain.com",
+ AssetsIDs: []string{"asset-id"},
+ },
+ expectedErrs: map[string]interface{}{
+ "homepage": "invalid URL scheme is not part of [https]",
+ },
+ enforceHTTPS: true,
+ },
+ {
+ name: "🟢 successfully validates the homepage,deep-link,client-domain and values get sanitized",
+ reqBody: &WalletRequest{
+ Name: "Wallet Provider",
+ Homepage: "https://homepage.com",
+ DeepLinkSchema: "wallet://deeplinkschema/sdp",
+ SEP10ClientDomain: "https://sep-10-client-domain.com",
+ AssetsIDs: []string{"asset-id"},
+ },
+ updateRequestFn: func(wr *WalletRequest) {
+ wr.SEP10ClientDomain = "sep-10-client-domain.com"
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ wv := NewWalletValidator()
+ reqBody := wv.ValidateCreateWalletRequest(ctx, tc.reqBody, tc.enforceHTTPS)
+
+ if len(tc.expectedErrs) == 0 {
+ require.Falsef(t, wv.HasErrors(), "expected no errors, got: %v", wv.Errors)
+ if tc.updateRequestFn != nil {
+ tc.updateRequestFn(tc.reqBody)
+ }
+ assert.Equal(t, tc.reqBody, reqBody)
+ } else {
+ assert.True(t, wv.HasErrors())
+ assert.Equal(t, tc.expectedErrs, wv.Errors)
+ }
+ })
+ }
}
func TestWalletValidator_ValidatePatchWalletRequest(t *testing.T) {
@@ -141,7 +124,7 @@ func TestWalletValidator_ValidatePatchWalletRequest(t *testing.T) {
t.Run("returns error when request body is empty", func(t *testing.T) {
wv := NewWalletValidator()
- wv.ValidateCreateWalletRequest(ctx, nil)
+ wv.ValidateCreateWalletRequest(ctx, nil, false)
assert.True(t, wv.HasErrors())
assert.Equal(t, map[string]interface{}{"body": "request body is empty"}, wv.Errors)
})
diff --git a/internal/services/assets/assets_pubnet.go b/internal/services/assets/assets_pubnet.go
index 8e875a521..3854faede 100644
--- a/internal/services/assets/assets_pubnet.go
+++ b/internal/services/assets/assets_pubnet.go
@@ -2,6 +2,12 @@ package assets
import "github.com/stellar/stellar-disbursement-platform-backend/internal/data"
+var AllAssetsPubnet = []data.Asset{
+ EURCAssetPubnet,
+ USDCAssetPubnet,
+ XLMAsset,
+}
+
// USDC
const USDCAssetCode = "USDC"
diff --git a/internal/services/assets/assets_testnet.go b/internal/services/assets/assets_testnet.go
index 22a702827..32ba5cc3b 100644
--- a/internal/services/assets/assets_testnet.go
+++ b/internal/services/assets/assets_testnet.go
@@ -2,6 +2,11 @@ package assets
import "github.com/stellar/stellar-disbursement-platform-backend/internal/data"
+var AllAssetsTestnet = []data.Asset{
+ XLMAsset,
+ USDCAssetTestnet,
+}
+
// USDC
const USDCAssetIssuerTestnet = "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5"
diff --git a/internal/services/circle_reconciliation_service_test.go b/internal/services/circle_reconciliation_service_test.go
index c89549b67..26cd2e008 100644
--- a/internal/services/circle_reconciliation_service_test.go
+++ b/internal/services/circle_reconciliation_service_test.go
@@ -165,7 +165,6 @@ func Test_NewCircleReconciliationService_Reconcile_partialSuccess(t *testing.T)
require.NoError(t, err)
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, assets.EURCAssetCode, assets.EURCAssetTestnet.Issuer)
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "My Wallet", "https://www.wallet.com", "www.wallet.com", "wallet1://")
// Create distribution accounts
@@ -176,11 +175,10 @@ func Test_NewCircleReconciliationService_Reconcile_partialSuccess(t *testing.T)
}
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Name: "disbursement",
- Status: data.StartedDisbursementStatus,
- Asset: asset,
- Wallet: wallet,
- Country: country,
+ Name: "disbursement",
+ Status: data.StartedDisbursementStatus,
+ Asset: asset,
+ Wallet: wallet,
})
receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{})
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
@@ -332,15 +330,13 @@ func Test_NewCircleReconciliationService_reconcileTransferRequest(t *testing.T)
require.NoError(t, err)
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, assets.EURCAssetCode, assets.EURCAssetTestnet.Issuer)
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "My Wallet", "https://www.wallet.com", "www.wallet.com", "wallet1://")
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Name: "disbursement",
- Status: data.StartedDisbursementStatus,
- Asset: asset,
- Wallet: wallet,
- Country: country,
+ Name: "disbursement",
+ Status: data.StartedDisbursementStatus,
+ Asset: asset,
+ Wallet: wallet,
})
receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{})
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
diff --git a/internal/services/disbursement_management_service.go b/internal/services/disbursement_management_service.go
index ca3ffc498..9df0adf9d 100644
--- a/internal/services/disbursement_management_service.go
+++ b/internal/services/disbursement_management_service.go
@@ -266,12 +266,12 @@ func (s *DisbursementManagementService) StartDisbursement(ctx context.Context, d
}
if len(receiverWallets) != 0 {
- eventData := make([]schemas.EventReceiverWalletSMSInvitationData, 0, len(receiverWallets))
+ eventData := make([]schemas.EventReceiverWalletInvitationData, 0, len(receiverWallets))
for _, receiverWallet := range receiverWallets {
- eventData = append(eventData, schemas.EventReceiverWalletSMSInvitationData{ReceiverWalletID: receiverWallet.ID})
+ eventData = append(eventData, schemas.EventReceiverWalletInvitationData{ReceiverWalletID: receiverWallet.ID})
}
- sendInviteMsg, msgErr := events.NewMessage(ctx, events.ReceiverWalletNewInvitationTopic, disbursement.ID, events.BatchReceiverWalletSMSInvitationType, eventData)
+ sendInviteMsg, msgErr := events.NewMessage(ctx, events.ReceiverWalletNewInvitationTopic, disbursement.ID, events.BatchReceiverWalletInvitationType, eventData)
if msgErr != nil {
return nil, fmt.Errorf("creating new message: %w", msgErr)
}
diff --git a/internal/services/disbursement_management_service_test.go b/internal/services/disbursement_management_service_test.go
index 123e8aab2..a66f34792 100644
--- a/internal/services/disbursement_management_service_test.go
+++ b/internal/services/disbursement_management_service_test.go
@@ -207,10 +207,9 @@ func Test_DisbursementManagementService_StartDisbursement_success(t *testing.T)
// Create models and basic DB entries
models, err := data.NewModels(dbConnectionPool)
require.NoError(t, err)
- // Create fixtures: asset, wallet, country
+ // Create fixtures: asset, wallet
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, assets.EURCAssetCode, assets.EURCAssetIssuerTestnet)
wallet := data.CreateDefaultWalletFixture(t, ctx, dbConnectionPool)
- country := data.GetCountryFixture(t, ctx, dbConnectionPool, data.FixtureCountryUKR)
// Update context with tenant and auth token
tnt := tenant.Tenant{ID: "tenant-id"}
@@ -337,11 +336,10 @@ func Test_DisbursementManagementService_StartDisbursement_success(t *testing.T)
// Create fixtures: disbursements
readyDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Name: "ready disbursement",
- Status: data.ReadyDisbursementStatus,
- Asset: asset,
- Wallet: wallet,
- Country: country,
+ Name: "ready disbursement",
+ Status: data.ReadyDisbursementStatus,
+ Asset: asset,
+ Wallet: wallet,
StatusHistory: []data.DisbursementStatusHistoryEntry{
{UserID: ownerUser.ID, Status: data.DraftDisbursementStatus},
{UserID: ownerUser.ID, Status: data.ReadyDisbursementStatus},
@@ -407,13 +405,13 @@ func Test_DisbursementManagementService_StartDisbursement_success(t *testing.T)
sendInviteMsg := msgs[0]
assert.Equal(t, events.ReceiverWalletNewInvitationTopic, sendInviteMsg.Topic)
assert.Equal(t, readyDisbursement.ID, sendInviteMsg.Key)
- assert.Equal(t, events.BatchReceiverWalletSMSInvitationType, sendInviteMsg.Type)
+ assert.Equal(t, events.BatchReceiverWalletInvitationType, sendInviteMsg.Type)
assert.Equal(t, tnt.ID, sendInviteMsg.TenantID)
- eventData, ok := sendInviteMsg.Data.([]schemas.EventReceiverWalletSMSInvitationData)
+ eventData, ok := sendInviteMsg.Data.([]schemas.EventReceiverWalletInvitationData)
require.True(t, ok)
require.Len(t, eventData, 2)
- wantElements := []schemas.EventReceiverWalletSMSInvitationData{
+ wantElements := []schemas.EventReceiverWalletInvitationData{
{ReceiverWalletID: rwDraft.ID}, // <--- invitation for the receiver that is being included in the system for the first time
{ReceiverWalletID: rwReady.ID}, // <--- invitation for the receiver that is already in the system but doesn't have a Stellar wallet yet
}
@@ -507,22 +505,20 @@ func Test_DisbursementManagementService_StartDisbursement_failure(t *testing.T)
token := "token"
ctx = context.WithValue(ctx, middleware.TokenContextKey, token)
- // Create fixtures: asset, wallet, country
+ // Create fixtures: asset, wallet
asset := data.GetAssetFixture(t, ctx, dbConnectionPool, data.FixtureAssetUSDC)
distributionAccPubKey := "GAAHIL6ZW4QFNLCKALZ3YOIWPP4TXQ7B7J5IU7RLNVGQAV6GFDZHLDTA"
distributionAcc := schema.NewDefaultStellarTransactionAccount(distributionAccPubKey)
// create fixtures
wallet := data.CreateDefaultWalletFixture(t, ctx, dbConnectionPool)
- country := data.GetCountryFixture(t, ctx, dbConnectionPool, data.FixtureCountryUKR)
// Create fixtures: disbursements
draftDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Name: "draft disbursement",
- Status: data.DraftDisbursementStatus,
- Asset: asset,
- Wallet: wallet,
- Country: country,
+ Name: "draft disbursement",
+ Status: data.DraftDisbursementStatus,
+ Asset: asset,
+ Wallet: wallet,
})
// Create fixtures: receivers, receiver wallets
@@ -572,11 +568,10 @@ func Test_DisbursementManagementService_StartDisbursement_failure(t *testing.T)
userID := "9ae68f09-cad9-4311-9758-4ff59d2e9e6d"
disbursement := data.CreateDisbursementFixture(t, context.Background(), dbConnectionPool, models.Disbursements, &data.Disbursement{
- Name: "disbursement #1",
- Status: data.ReadyDisbursementStatus,
- Asset: asset,
- Wallet: wallet,
- Country: country,
+ Name: "disbursement #1",
+ Status: data.ReadyDisbursementStatus,
+ Asset: asset,
+ Wallet: wallet,
StatusHistory: []data.DisbursementStatusHistoryEntry{
{
Status: data.DraftDisbursementStatus,
@@ -612,11 +607,10 @@ func Test_DisbursementManagementService_StartDisbursement_failure(t *testing.T)
usdt := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDT", "GBVHJTRLQRMIHRYTXZQOPVYCVVH7IRJN3DOFT7VC6U75CBWWBVDTWURG")
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Name: "disbursement - balance insufficient",
- Status: data.StartedDisbursementStatus,
- Asset: usdt,
- Wallet: wallet,
- Country: country,
+ Name: "disbursement - balance insufficient",
+ Status: data.StartedDisbursementStatus,
+ Asset: usdt,
+ Wallet: wallet,
})
// should consider this payment since it's the same asset
data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -628,11 +622,10 @@ func Test_DisbursementManagementService_StartDisbursement_failure(t *testing.T)
})
disbursement2 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Name: "disbursement #4",
- Status: data.StartedDisbursementStatus,
- Asset: asset,
- Wallet: wallet,
- Country: country,
+ Name: "disbursement #4",
+ Status: data.StartedDisbursementStatus,
+ Asset: asset,
+ Wallet: wallet,
})
// should NOT consider this payment since it's NOT the same asset
data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -644,11 +637,10 @@ func Test_DisbursementManagementService_StartDisbursement_failure(t *testing.T)
})
disbursementInsufficientBalance := data.CreateDisbursementFixture(t, context.Background(), dbConnectionPool, models.Disbursements, &data.Disbursement{
- Name: "disbursement - insufficient balance",
- Status: data.ReadyDisbursementStatus,
- Asset: usdt,
- Wallet: wallet,
- Country: country,
+ Name: "disbursement - insufficient balance",
+ Status: data.ReadyDisbursementStatus,
+ Asset: usdt,
+ Wallet: wallet,
})
data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
ReceiverWallet: rwReady,
@@ -724,7 +716,6 @@ func Test_DisbursementManagementService_StartDisbursement_failure(t *testing.T)
Status: data.ReadyDisbursementStatus,
Asset: asset,
Wallet: wallet,
- Country: country,
StatusHistory: statusHistory,
})
@@ -749,8 +740,8 @@ func Test_DisbursementManagementService_StartDisbursement_failure(t *testing.T)
Topic: events.ReceiverWalletNewInvitationTopic,
Key: disbursement.ID,
TenantID: tnt.ID,
- Type: events.BatchReceiverWalletSMSInvitationType,
- Data: []schemas.EventReceiverWalletSMSInvitationData{
+ Type: events.BatchReceiverWalletInvitationType,
+ Data: []schemas.EventReceiverWalletInvitationData{
{ReceiverWalletID: rwReady.ID}, // Receiver that can receive SMS
},
},
@@ -829,7 +820,6 @@ func Test_DisbursementManagementService_StartDisbursement_failure(t *testing.T)
Status: data.ReadyDisbursementStatus,
Asset: asset,
Wallet: wallet,
- Country: country,
StatusHistory: statusHistory,
})
@@ -852,8 +842,8 @@ func Test_DisbursementManagementService_StartDisbursement_failure(t *testing.T)
Topic: events.ReceiverWalletNewInvitationTopic,
Key: disbursement.ID,
TenantID: tnt.ID,
- Type: events.BatchReceiverWalletSMSInvitationType,
- Data: []schemas.EventReceiverWalletSMSInvitationData{
+ Type: events.BatchReceiverWalletInvitationType,
+ Data: []schemas.EventReceiverWalletInvitationData{
{ReceiverWalletID: rwReady.ID}, // Receiver that can receive SMS
},
},
@@ -907,7 +897,6 @@ func Test_DisbursementManagementService_StartDisbursement_failure(t *testing.T)
Status: data.ReadyDisbursementStatus,
Asset: asset,
Wallet: wallet,
- Country: country,
StatusHistory: statusHistory,
})
@@ -961,7 +950,6 @@ func Test_DisbursementManagementService_StartDisbursement_failure(t *testing.T)
Status: data.ReadyDisbursementStatus,
Asset: asset,
Wallet: wallet,
- Country: country,
StatusHistory: statusHistory,
})
@@ -1013,8 +1001,8 @@ func Test_DisbursementManagementService_StartDisbursement_failure(t *testing.T)
Topic: events.ReceiverWalletNewInvitationTopic,
Key: disbursement.ID,
TenantID: tnt.ID,
- Type: events.BatchReceiverWalletSMSInvitationType,
- Data: []schemas.EventReceiverWalletSMSInvitationData{
+ Type: events.BatchReceiverWalletInvitationType,
+ Data: []schemas.EventReceiverWalletInvitationData{
{
ReceiverWalletID: rwReady.ID,
},
@@ -1089,23 +1077,20 @@ func Test_DisbursementManagementService_PauseDisbursement(t *testing.T) {
// create fixtures
wallet := data.CreateDefaultWalletFixture(t, ctx, dbConnectionPool)
- country := data.GetCountryFixture(t, ctx, dbConnectionPool, data.FixtureCountryUSA)
// create disbursements
readyDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Name: "ready disbursement",
- Status: data.ReadyDisbursementStatus,
- Asset: asset,
- Wallet: wallet,
- Country: country,
+ Name: "ready disbursement",
+ Status: data.ReadyDisbursementStatus,
+ Asset: asset,
+ Wallet: wallet,
})
startedDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Name: "started disbursement",
- Status: data.StartedDisbursementStatus,
- Asset: asset,
- Wallet: wallet,
- Country: country,
+ Name: "started disbursement",
+ Status: data.StartedDisbursementStatus,
+ Asset: asset,
+ Wallet: wallet,
})
// create disbursement receivers
@@ -1354,15 +1339,13 @@ func Test_DisbursementManagementService_validateBalanceForDisbursement(t *testin
models, outerErr := data.NewModels(dbConnectionPool)
require.NoError(t, outerErr)
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://")
receiverReady := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{})
rwReady := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiverReady.ID, wallet.ID, data.ReadyReceiversWalletStatus)
disbursementOld := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
- Wallet: wallet,
- Status: data.ReadyDisbursementStatus,
- Asset: asset,
+ Wallet: wallet,
+ Status: data.ReadyDisbursementStatus,
+ Asset: asset,
})
_ = data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
ReceiverWallet: rwReady,
@@ -1372,10 +1355,9 @@ func Test_DisbursementManagementService_validateBalanceForDisbursement(t *testin
Status: data.PendingPaymentStatus,
})
disbursementNew := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
- Wallet: wallet,
- Status: data.ReadyDisbursementStatus,
- Asset: asset,
+ Wallet: wallet,
+ Status: data.ReadyDisbursementStatus,
+ Asset: asset,
})
_ = data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
ReceiverWallet: rwReady,
diff --git a/internal/services/mocks/send_receiver_wallets_invite_service.go b/internal/services/mocks/send_receiver_wallets_invite_service.go
index 252c7ce3a..45346092f 100644
--- a/internal/services/mocks/send_receiver_wallets_invite_service.go
+++ b/internal/services/mocks/send_receiver_wallets_invite_service.go
@@ -12,7 +12,7 @@ type MockSendReceiverWalletInviteService struct {
mock.Mock
}
-func (s *MockSendReceiverWalletInviteService) SendInvite(ctx context.Context, receiverWalletsReq ...schemas.EventReceiverWalletSMSInvitationData) error {
+func (s *MockSendReceiverWalletInviteService) SendInvite(ctx context.Context, receiverWalletsReq ...schemas.EventReceiverWalletInvitationData) error {
args := s.Called(ctx, receiverWalletsReq)
return args.Error(0)
}
diff --git a/internal/services/patch_anchor_platform_transactions_completion_test.go b/internal/services/patch_anchor_platform_transactions_completion_test.go
index d078d3a14..544703ab8 100644
--- a/internal/services/patch_anchor_platform_transactions_completion_test.go
+++ b/internal/services/patch_anchor_platform_transactions_completion_test.go
@@ -66,7 +66,6 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionForP
t.Run("doesn't patch the transaction when payment isn't on Success or Failed status", func(t *testing.T) {
data.DeleteAllFixtures(t, ctx, dbConnectionPool)
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://")
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
@@ -74,11 +73,10 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionForP
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
Wallet: wallet,
Asset: asset,
Status: data.StartedDisbursementStatus,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
})
payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -103,7 +101,6 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionForP
t.Run("doesn't mark as synced when fails patching anchor platform transaction when payment is success", func(t *testing.T) {
data.DeleteAllFixtures(t, ctx, dbConnectionPool)
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://")
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
@@ -111,11 +108,10 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionForP
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
Wallet: wallet,
Asset: asset,
Status: data.StartedDisbursementStatus,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
})
payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -174,7 +170,6 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionForP
t.Run("mark as synced when patch anchor platform transaction successfully and payment is failed", func(t *testing.T) {
data.DeleteAllFixtures(t, ctx, dbConnectionPool)
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://")
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
@@ -182,11 +177,10 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionForP
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
Wallet: wallet,
Asset: asset,
Status: data.StartedDisbursementStatus,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
})
payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -228,7 +222,6 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionForP
t.Run("marks as synced when patch anchor platform transaction successfully and payment is success", func(t *testing.T) {
data.DeleteAllFixtures(t, ctx, dbConnectionPool)
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://")
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
@@ -236,11 +229,10 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionForP
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
Wallet: wallet,
Asset: asset,
Status: data.StartedDisbursementStatus,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
})
payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -293,7 +285,6 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionForP
t.Run("marks as synced when patch anchor platform transaction successfully and payment is success (XLM)", func(t *testing.T) {
data.DeleteAllFixtures(t, ctx, dbConnectionPool)
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://")
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "XLM", "")
@@ -301,11 +292,10 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionForP
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
Wallet: wallet,
Asset: asset,
Status: data.StartedDisbursementStatus,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
})
payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -358,7 +348,6 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionForP
t.Run("doesn't patch the transaction when it's already patch as completed", func(t *testing.T) {
data.DeleteAllFixtures(t, ctx, dbConnectionPool)
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://")
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
@@ -366,11 +355,10 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionForP
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
disbursement1 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
Wallet: wallet,
Asset: asset,
Status: data.StartedDisbursementStatus,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
})
payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -442,7 +430,6 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionsFor
t.Run("doesn't mark as synced when fails patching anchor platform transaction when payment is success", func(t *testing.T) {
data.DeleteAllFixtures(t, ctx, dbConnectionPool)
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://")
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
@@ -450,11 +437,10 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionsFor
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
Wallet: wallet,
Asset: asset,
Status: data.StartedDisbursementStatus,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
})
payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -506,7 +492,6 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionsFor
t.Run("mark as synced when patch anchor platform transaction successfully and payment is failed", func(t *testing.T) {
data.DeleteAllFixtures(t, ctx, dbConnectionPool)
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://")
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
@@ -514,11 +499,10 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionsFor
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
Wallet: wallet,
Asset: asset,
Status: data.StartedDisbursementStatus,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
})
payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -566,7 +550,6 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionsFor
t.Run("marks as synced when patch anchor platform transaction successfully and payment is success", func(t *testing.T) {
data.DeleteAllFixtures(t, ctx, dbConnectionPool)
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://")
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
@@ -574,11 +557,10 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionsFor
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
Wallet: wallet,
Asset: asset,
Status: data.StartedDisbursementStatus,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
})
payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -629,7 +611,6 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionsFor
t.Run("doesn't patch the transaction when it's already patch as completed", func(t *testing.T) {
data.DeleteAllFixtures(t, ctx, dbConnectionPool)
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://")
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
@@ -637,19 +618,17 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionsFor
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
disbursement1 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
Wallet: wallet,
Asset: asset,
Status: data.StartedDisbursementStatus,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
})
disbursement2 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
Wallet: wallet,
Asset: asset,
Status: data.StartedDisbursementStatus,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
})
payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -713,7 +692,6 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionsFor
t.Run("patches the transactions successfully if the other payments were failed", func(t *testing.T) {
data.DeleteAllFixtures(t, ctx, dbConnectionPool)
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://")
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
@@ -721,27 +699,24 @@ func Test_PatchAnchorPlatformTransactionCompletionService_PatchAPTransactionsFor
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
disbursement1 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
Wallet: wallet,
Asset: asset,
Status: data.StartedDisbursementStatus,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
})
disbursement2 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
Wallet: wallet,
Asset: asset,
Status: data.StartedDisbursementStatus,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
})
disbursement3 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
Wallet: wallet,
Asset: asset,
Status: data.StartedDisbursementStatus,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
})
payment1 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
diff --git a/internal/services/payment_from_submitter_service_test.go b/internal/services/payment_from_submitter_service_test.go
index 5b7c493cc..c009bb13c 100644
--- a/internal/services/payment_from_submitter_service_test.go
+++ b/internal/services/payment_from_submitter_service_test.go
@@ -63,17 +63,13 @@ func Test_PaymentFromSubmitterService_SyncBatchTransactions(t *testing.T) {
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool,
"USDC",
"GABC65XJDMXTGPNZRCI6V3KOKKWVK55UEKGQLONRIVYPMEJNNQ45YOEE")
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool,
- "FRA",
- "France")
// create disbursements
startedDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, testCtx.sdpModel.Disbursements, &data.Disbursement{
- Name: "ready disbursement",
- Status: data.StartedDisbursementStatus,
- Asset: asset,
- Wallet: wallet,
- Country: country,
+ Name: "ready disbursement",
+ Status: data.StartedDisbursementStatus,
+ Asset: asset,
+ Wallet: wallet,
})
// create disbursement receivers
@@ -270,17 +266,13 @@ func Test_PaymentFromSubmitterService_SyncTransaction(t *testing.T) {
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool,
"USDC",
"GABC65XJDMXTGPNZRCI6V3KOKKWVK55UEKGQLONRIVYPMEJNNQ45YOEE")
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool,
- "FRA",
- "France")
// create disbursements
startedDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, testCtx.sdpModel.Disbursements, &data.Disbursement{
- Name: "ready disbursement",
- Status: data.StartedDisbursementStatus,
- Asset: asset,
- Wallet: wallet,
- Country: country,
+ Name: "ready disbursement",
+ Status: data.StartedDisbursementStatus,
+ Asset: asset,
+ Wallet: wallet,
})
// create disbursement receivers
@@ -578,7 +570,6 @@ func updateTSSTransactionsToError(t *testing.T, testCtx *testContext, txDataSlic
func Test_PaymentFromSubmitterService_RetryingPayment(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()
-
dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN)
require.NoError(t, outerErr)
defer dbConnectionPool.Close()
@@ -588,17 +579,7 @@ func Test_PaymentFromSubmitterService_RetryingPayment(t *testing.T) {
monitorService := NewPaymentFromSubmitterService(testCtx.sdpModel, dbConnectionPool)
- // clean test db
- data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllWalletFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllCountryFixtures(t, ctx, dbConnectionPool)
-
// create fixtures
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://")
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GABC65XJDMXTGPNZRCI6V3KOKKWVK55UEKGQLONRIVYPMEJNNQ45YOEE")
@@ -606,11 +587,10 @@ func Test_PaymentFromSubmitterService_RetryingPayment(t *testing.T) {
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, testCtx.sdpModel.Disbursements, &data.Disbursement{
- Name: "started disbursement",
- Status: data.StartedDisbursementStatus,
- Asset: asset,
- Wallet: wallet,
- Country: country,
+ Name: "started disbursement",
+ Status: data.StartedDisbursementStatus,
+ Asset: asset,
+ Wallet: wallet,
})
payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, testCtx.sdpModel.Payment, &data.Payment{
@@ -704,7 +684,6 @@ func Test_PaymentFromSubmitterService_RetryingPayment(t *testing.T) {
func Test_PaymentFromSubmitterService_CompleteDisbursements(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()
-
dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN)
require.NoError(t, outerErr)
defer dbConnectionPool.Close()
@@ -714,17 +693,7 @@ func Test_PaymentFromSubmitterService_CompleteDisbursements(t *testing.T) {
monitorService := NewPaymentFromSubmitterService(testCtx.sdpModel, dbConnectionPool)
- // clean test db
- data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllWalletFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllCountryFixtures(t, ctx, dbConnectionPool)
-
// create fixtures
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://")
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GABC65XJDMXTGPNZRCI6V3KOKKWVK55UEKGQLONRIVYPMEJNNQ45YOEE")
@@ -732,11 +701,10 @@ func Test_PaymentFromSubmitterService_CompleteDisbursements(t *testing.T) {
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, testCtx.sdpModel.Disbursements, &data.Disbursement{
- Name: "started disbursement",
- Status: data.StartedDisbursementStatus,
- Asset: asset,
- Wallet: wallet,
- Country: country,
+ Name: "started disbursement",
+ Status: data.StartedDisbursementStatus,
+ Asset: asset,
+ Wallet: wallet,
})
payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, testCtx.sdpModel.Payment, &data.Payment{
diff --git a/internal/services/payment_management_service_test.go b/internal/services/payment_management_service_test.go
index dce483945..7b61447e0 100644
--- a/internal/services/payment_management_service_test.go
+++ b/internal/services/payment_management_service_test.go
@@ -31,15 +31,13 @@ func Test_PaymentManagementService_CancelPayment(t *testing.T) {
// create fixtures
wallet := data.CreateDefaultWalletFixture(t, ctx, dbConnectionPool)
asset := data.GetAssetFixture(t, ctx, dbConnectionPool, data.FixtureAssetUSDC)
- country := data.GetCountryFixture(t, ctx, dbConnectionPool, data.FixtureCountryUSA)
// create disbursements
startedDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Name: "ready disbursement",
- Status: data.StartedDisbursementStatus,
- Asset: asset,
- Wallet: wallet,
- Country: country,
+ Name: "ready disbursement",
+ Status: data.StartedDisbursementStatus,
+ Asset: asset,
+ Wallet: wallet,
})
// create disbursement receivers
diff --git a/internal/services/payment_to_submitter_service_test.go b/internal/services/payment_to_submitter_service_test.go
index 8fa2fc734..6dbcc1c43 100644
--- a/internal/services/payment_to_submitter_service_test.go
+++ b/internal/services/payment_to_submitter_service_test.go
@@ -39,7 +39,6 @@ func Test_PaymentToSubmitterService_SendPaymentsMethods(t *testing.T) {
eurcAsset := data.CreateAssetFixture(t, ctx, dbConnectionPool, assets.EURCAssetCode, assets.EURCAssetTestnet.Issuer)
nativeAsset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "XLM", "")
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "My Wallet", "https://www.wallet.com", "www.wallet.com", "wallet1://")
models, err := data.NewModels(dbConnectionPool)
@@ -129,11 +128,10 @@ func Test_PaymentToSubmitterService_SendPaymentsMethods(t *testing.T) {
defer data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool)
startedDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Name: "ready disbursement",
- Status: data.StartedDisbursementStatus,
- Asset: tc.asset,
- Wallet: wallet,
- Country: country,
+ Name: "ready disbursement",
+ Status: data.StartedDisbursementStatus,
+ Asset: tc.asset,
+ Wallet: wallet,
})
receiverReady := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{})
@@ -446,7 +444,6 @@ func Test_PaymentToSubmitterService_ValidatePaymentReadyForSending(t *testing.T)
func Test_PaymentToSubmitterService_RetryPayment(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()
-
dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
require.NoError(t, err)
defer dbConnectionPool.Close()
@@ -472,17 +469,7 @@ func Test_PaymentToSubmitterService_RetryPayment(t *testing.T) {
PaymentDispatcher: paymentDispatcher,
})
- // clean test db
- data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllWalletFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllCountryFixtures(t, ctx, dbConnectionPool)
-
// create fixtures
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://")
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GDUCE34WW5Z34GMCEPURYANUCUP47J6NORJLKC6GJNMDLN4ZI4PMI2MG")
@@ -490,11 +477,10 @@ func Test_PaymentToSubmitterService_RetryPayment(t *testing.T) {
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Name: "started disbursement",
- Status: data.StartedDisbursementStatus,
- Asset: asset,
- Wallet: wallet,
- Country: country,
+ Name: "started disbursement",
+ Status: data.StartedDisbursementStatus,
+ Asset: asset,
+ Wallet: wallet,
})
payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -592,13 +578,11 @@ func Test_PaymentToSubmitterService_markPaymentsAsFailed(t *testing.T) {
models, err := data.NewModels(dbConnectionPool)
require.NoError(t, err)
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://")
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
- Wallet: wallet,
- Status: data.ReadyDisbursementStatus,
- Asset: asset,
+ Wallet: wallet,
+ Status: data.ReadyDisbursementStatus,
+ Asset: asset,
})
receiverReady := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{})
rwReady := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiverReady.ID, wallet.ID, data.ReadyReceiversWalletStatus)
diff --git a/internal/services/ready_payments_cancelation_service_test.go b/internal/services/ready_payments_cancelation_service_test.go
index b35dcb3c2..218f5252c 100644
--- a/internal/services/ready_payments_cancelation_service_test.go
+++ b/internal/services/ready_payments_cancelation_service_test.go
@@ -17,7 +17,6 @@ import (
func Test_ReadyPaymentsCancellationService_CancelReadyPaymentsService(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()
-
dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
require.NoError(t, err)
defer dbConnectionPool.Close()
@@ -28,15 +27,6 @@ func Test_ReadyPaymentsCancellationService_CancelReadyPaymentsService(t *testing
service := NewReadyPaymentsCancellationService(models)
ctx := context.Background()
- data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllCountryFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool)
- data.DeleteAllWalletFixtures(t, ctx, dbConnectionPool)
-
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://")
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
@@ -44,11 +34,10 @@ func Test_ReadyPaymentsCancellationService_CancelReadyPaymentsService(t *testing
receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus)
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
Wallet: wallet,
Asset: asset,
Status: data.ReadyDisbursementStatus,
- VerificationField: data.VerificationFieldDateOfBirth,
+ VerificationField: data.VerificationTypeDateOfBirth,
})
t.Run("automatic payment cancellation is deactivated", func(t *testing.T) {
diff --git a/internal/services/send_invitation_message.go b/internal/services/send_invitation_message.go
index c46dc834a..43a6512c8 100644
--- a/internal/services/send_invitation_message.go
+++ b/internal/services/send_invitation_message.go
@@ -57,20 +57,20 @@ func SendInvitationMessage(ctx context.Context, messengerClient message.Messenge
return fmt.Errorf("getting forgot password link: %w", err)
}
- invitationMsgData := htmltemplate.InvitationMessageTemplate{
+ invitationMsgData := htmltemplate.StaffInvitationEmailMessageTemplate{
FirstName: opts.FirstName,
Role: opts.Role,
ForgotPasswordLink: forgotPasswordLink,
OrganizationName: organization.Name,
}
- messageContent, err := htmltemplate.ExecuteHTMLTemplateForInvitationMessage(invitationMsgData)
+ messageContent, err := htmltemplate.ExecuteHTMLTemplateForStaffInvitationEmailMessage(invitationMsgData)
if err != nil {
return fmt.Errorf("executing invitation message HTML template: %w", err)
}
msg := message.Message{
ToEmail: opts.Email,
- Message: messageContent,
+ Body: messageContent,
Title: invitationMessageTitle,
}
diff --git a/internal/services/send_invitation_message_test.go b/internal/services/send_invitation_message_test.go
index 1c493e20d..cdabe6f45 100644
--- a/internal/services/send_invitation_message_test.go
+++ b/internal/services/send_invitation_message_test.go
@@ -81,7 +81,7 @@ func Test_SendInvitationMessage(t *testing.T) {
forgotPasswordLink, err := url.JoinPath(uiBaseURL, "forgot-password")
require.NoError(t, err)
- content, err := htmltemplate.ExecuteHTMLTemplateForInvitationMessage(htmltemplate.InvitationMessageTemplate{
+ content, err := htmltemplate.ExecuteHTMLTemplateForStaffInvitationEmailMessage(htmltemplate.StaffInvitationEmailMessageTemplate{
FirstName: firstName,
Role: roles[0],
ForgotPasswordLink: forgotPasswordLink,
@@ -122,7 +122,7 @@ func Test_SendInvitationMessage(t *testing.T) {
On("SendMessage", message.Message{
ToEmail: email,
Title: invitationMessageTitle,
- Message: content,
+ Body: content,
}).
Return(errors.New("foobar")).
Once()
@@ -142,7 +142,7 @@ func Test_SendInvitationMessage(t *testing.T) {
On("SendMessage", message.Message{
ToEmail: email,
Title: invitationMessageTitle,
- Message: content,
+ Body: content,
}).
Return(nil).
Once()
diff --git a/internal/services/send_receiver_wallets_invite_service.go b/internal/services/send_receiver_wallets_invite_service.go
index 85988ad91..46251b634 100644
--- a/internal/services/send_receiver_wallets_invite_service.go
+++ b/internal/services/send_receiver_wallets_invite_service.go
@@ -23,22 +23,22 @@ import (
)
type SendReceiverWalletInviteServiceInterface interface {
- SendInvite(ctx context.Context, receiverWalletInvitationData ...schemas.EventReceiverWalletSMSInvitationData) error
+ SendInvite(ctx context.Context, receiverWalletInvitationData ...schemas.EventReceiverWalletInvitationData) error
}
type SendReceiverWalletInviteService struct {
- messengerClient message.MessengerClient
- Models *data.Models
- maxInvitationSMSResendAttempts int64
- sep10SigningPrivateKey string
- crashTrackerClient crashtracker.CrashTrackerClient
+ messageDispatcher message.MessageDispatcherInterface
+ Models *data.Models
+ maxInvitationResendAttempts int64
+ sep10SigningPrivateKey string
+ crashTrackerClient crashtracker.CrashTrackerClient
}
var _ SendReceiverWalletInviteServiceInterface = new(SendReceiverWalletInviteService)
func (s SendReceiverWalletInviteService) validate() error {
- if s.messengerClient == nil {
- return fmt.Errorf("messenger client can't be nil")
+ if s.messageDispatcher == nil {
+ return fmt.Errorf("messenger dispatcher can't be nil")
}
return nil
@@ -49,44 +49,44 @@ func (s SendReceiverWalletInviteService) validate() error {
// For instance, the Wallet Foo is in two Ready Payments, one with USDC and the other with EUROC.
// So the receiver who has a Stellar Address pending registration (status:READY) in this wallet will receive both invites for USDC and EUROC.
// This would not impact the user receiving both token amounts. It's only for the registration process.
-func (s SendReceiverWalletInviteService) SendInvite(ctx context.Context, receiverWalletInvitationData ...schemas.EventReceiverWalletSMSInvitationData) error {
+func (s SendReceiverWalletInviteService) SendInvite(ctx context.Context, receiverWalletInvitationData ...schemas.EventReceiverWalletInvitationData) error {
if s.Models == nil {
return fmt.Errorf("SendReceiverWalletInviteService.Models cannot be nil")
}
currentTenant, err := tenant.GetTenantFromContext(ctx)
if err != nil {
- return fmt.Errorf("error getting tenant from context: %w", err)
+ return fmt.Errorf("getting tenant from context: %w", err)
}
if currentTenant.BaseURL == nil {
return fmt.Errorf("tenant base URL cannot be nil for tenant %s", currentTenant.ID)
}
- // Get the organization entry to get the Org name and SMSRegistrationMessageTemplate
+ // Get the organization entry to get the Org name and ReceiverRegistrationMessageTemplate
organization, err := s.Models.Organizations.Get(ctx)
if err != nil {
- return fmt.Errorf("error getting organization: %w", err)
+ return fmt.Errorf("getting organization: %w", err)
}
// Debug purposes
- if organization.SMSResendInterval == nil {
- log.Ctx(ctx).Debug("automatic resend invitation SMS is deactivated. Set a valid value to the organization's sms_resend_interval to activate it.")
+ if organization.ReceiverInvitationResendIntervalDays == nil {
+ log.Ctx(ctx).Debug("automatic resend invitation is deactivated. Set a valid value to the organization's receiver_invitation_resend_interval_days to activate it.")
}
- orgSMSRegistrationMessageTemplate := organization.SMSRegistrationMessageTemplate
- if !strings.Contains(orgSMSRegistrationMessageTemplate, "{{.RegistrationLink}}") {
- orgSMSRegistrationMessageTemplate = fmt.Sprintf("%s {{.RegistrationLink}}", strings.TrimSpace(orgSMSRegistrationMessageTemplate))
+ orgReceiverRegistrationMessageTemplate := organization.ReceiverRegistrationMessageTemplate
+ if !strings.Contains(orgReceiverRegistrationMessageTemplate, "{{.RegistrationLink}}") {
+ orgReceiverRegistrationMessageTemplate = fmt.Sprintf("%s {{.RegistrationLink}}", strings.TrimSpace(orgReceiverRegistrationMessageTemplate))
}
// Execute the template early so we avoid hitting the database to query the other info
- msgTemplate, err := template.New("").Parse(orgSMSRegistrationMessageTemplate)
+ msgTemplate, err := template.New("").Parse(orgReceiverRegistrationMessageTemplate)
if err != nil {
- return fmt.Errorf("error parsing organization SMS registration message template: %w", err)
+ return fmt.Errorf("parsing organization receiver registration message template: %w", err)
}
wallets, err := s.Models.Wallets.GetAll(ctx)
if err != nil {
- return fmt.Errorf("error getting all wallets: %w", err)
+ return fmt.Errorf("getting all wallets: %w", err)
}
walletsMap := make(map[string]data.Wallet, len(wallets))
@@ -96,19 +96,19 @@ func (s SendReceiverWalletInviteService) SendInvite(ctx context.Context, receive
receiverWallets, err := s.resolveReceiverWalletsPendingRegistration(ctx, receiverWalletInvitationData)
if err != nil {
- return fmt.Errorf("error resolving receiver wallets pending registration: %w", err)
+ return fmt.Errorf("resolving receiver wallets pending registration: %w", err)
}
receiverWalletsAsset, err := s.Models.Assets.GetAssetsPerReceiverWallet(ctx, receiverWallets...)
if err != nil {
- return fmt.Errorf("error getting all assets: %w", err)
+ return fmt.Errorf("getting all assets: %w", err)
}
msgsToInsert := []*data.MessageInsert{}
receiverWalletIDs := []string{}
// TODO: improve this code adding go routines
for _, rwa := range receiverWalletsAsset {
- if !s.shouldSendInvitationSMS(ctx, organization, &rwa) {
+ if !s.shouldSendInvitation(ctx, organization, &rwa) {
continue
}
@@ -131,15 +131,15 @@ func (s SendReceiverWalletInviteService) SendInvite(ctx context.Context, receive
continue
}
- disbursementSMSRegistrationMessageTemplate := rwa.DisbursementSMSTemplate
- if disbursementSMSRegistrationMessageTemplate != nil && *disbursementSMSRegistrationMessageTemplate != "" {
- if !strings.Contains(*disbursementSMSRegistrationMessageTemplate, "{{.RegistrationLink}}") {
- *disbursementSMSRegistrationMessageTemplate = fmt.Sprintf("%s {{.RegistrationLink}}", strings.TrimSpace(*disbursementSMSRegistrationMessageTemplate))
+ disbursementReceiverRegistrationMessageTemplate := rwa.DisbursementReceiverRegistrationMsgTemplate
+ if disbursementReceiverRegistrationMessageTemplate != nil && *disbursementReceiverRegistrationMessageTemplate != "" {
+ if !strings.Contains(*disbursementReceiverRegistrationMessageTemplate, "{{.RegistrationLink}}") {
+ *disbursementReceiverRegistrationMessageTemplate = fmt.Sprintf("%s {{.RegistrationLink}}", strings.TrimSpace(*disbursementReceiverRegistrationMessageTemplate))
}
- msgTemplate, err = template.New("").Parse(*disbursementSMSRegistrationMessageTemplate)
+ msgTemplate, err = template.New("").Parse(*disbursementReceiverRegistrationMessageTemplate)
if err != nil {
- return fmt.Errorf("error parsing disbursement SMS registration message template: %w", err)
+ return fmt.Errorf("parsing disbursement receiver registration message template: %w", err)
}
}
@@ -152,42 +152,45 @@ func (s SendReceiverWalletInviteService) SendInvite(ctx context.Context, receive
RegistrationLink: template.HTML(registrationLink),
})
if err != nil {
- return fmt.Errorf("error executing registration message template: %w", err)
+ return fmt.Errorf("executing registration message template: %w", err)
}
- msg := message.Message{
- ToPhoneNumber: rwa.ReceiverWallet.Receiver.PhoneNumber,
- Message: content.String(),
+ msg := message.Message{Body: content.String()}
+ if rwa.ReceiverWallet.Receiver.PhoneNumber != "" {
+ msg.ToPhoneNumber = rwa.ReceiverWallet.Receiver.PhoneNumber
+ }
+ if rwa.ReceiverWallet.Receiver.Email != "" {
+ msg.ToEmail = rwa.ReceiverWallet.Receiver.Email
+ msg.Title = "You have a payment waiting for you from " + organization.Name
}
- assetID := rwa.Asset.ID
- receiverWalletID := rwa.ReceiverWallet.ID
- messageType := s.messengerClient.MessengerType()
msgToInsert := &data.MessageInsert{
- Type: messageType,
- AssetID: &assetID,
+ AssetID: &rwa.Asset.ID,
ReceiverID: rwa.ReceiverWallet.Receiver.ID,
WalletID: wallet.ID,
- ReceiverWalletID: &receiverWalletID,
- TextEncrypted: content.String(),
+ ReceiverWalletID: &rwa.ReceiverWallet.ID,
+ TextEncrypted: msg.Body,
+ TitleEncrypted: msg.Title,
}
- // We assume that the message will be sent at first
- msgToInsert.Status = data.SuccessMessageStatus
- if err := s.messengerClient.SendMessage(msg); err != nil {
- msg := fmt.Sprintf(
+ if messengerType, sendErr := s.messageDispatcher.SendMessage(ctx, msg, organization.MessageChannelPriority); sendErr != nil {
+ errMsg := fmt.Sprintf(
"error sending message to receiver ID %s for receiver wallet ID %s using messenger type %s",
- rwa.ReceiverWallet.Receiver.ID, rwa.ReceiverWallet.ID, messageType,
+ rwa.ReceiverWallet.Receiver.ID, rwa.ReceiverWallet.ID, messengerType,
)
// call crash tracker client to log and report error
- s.crashTrackerClient.LogAndReportErrors(ctx, err, msg)
+ s.crashTrackerClient.LogAndReportErrors(ctx, sendErr, errMsg)
msgToInsert.Status = data.FailureMessageStatus
+ msgToInsert.Type = messengerType
+ } else {
+ msgToInsert.Status = data.SuccessMessageStatus
+ msgToInsert.Type = messengerType
}
msgsToInsert = append(msgsToInsert, msgToInsert)
- // We don't want to update the `invitation_sent_at` for receiver wallets that we've sent the invitation SMS
- // because there's no way to calculate how many times we've resent the invitation SMS since
+ // We don't want to update the `invitation_sent_at` for receiver wallets for which we've already sent the invitation message
+ // because there's no way to calculate how many times we've resent the invitation message since
// the first invitation if we update it.
if rwa.ReceiverWallet.InvitationSentAt == nil && msgToInsert.Status == data.SuccessMessageStatus {
receiverWalletIDs = append(receiverWalletIDs, rwa.ReceiverWallet.ID)
@@ -200,7 +203,7 @@ func (s SendReceiverWalletInviteService) SendInvite(ctx context.Context, receive
}
if err := s.Models.Message.BulkInsert(ctx, dbTx, msgsToInsert); err != nil {
- return fmt.Errorf("error inserting messages in the database: %w", err)
+ return fmt.Errorf("inserting messages in the database: %w", err)
}
return nil
@@ -209,7 +212,7 @@ func (s SendReceiverWalletInviteService) SendInvite(ctx context.Context, receive
// resolveReceiverWalletsPendingRegistration returns the receiver wallets pending registration based on the receiverWalletInvitationData.
// If the receiverWalletInvitationData is empty, it will return all receiver wallets pending registration.
-func (s SendReceiverWalletInviteService) resolveReceiverWalletsPendingRegistration(ctx context.Context, receiverWalletInvitationData []schemas.EventReceiverWalletSMSInvitationData) ([]*data.ReceiverWallet, error) {
+func (s SendReceiverWalletInviteService) resolveReceiverWalletsPendingRegistration(ctx context.Context, receiverWalletInvitationData []schemas.EventReceiverWalletInvitationData) ([]*data.ReceiverWallet, error) {
var err error
var receiverWallets []*data.ReceiverWallet
if len(receiverWalletInvitationData) == 0 {
@@ -230,52 +233,49 @@ func (s SendReceiverWalletInviteService) resolveReceiverWalletsPendingRegistrati
return receiverWallets, err
}
-// shouldSendInvitationSMS returns true if we should send the invitation SMS to the receiver. It will be used to either
-// send the invitation for the first time, or to resend it automatically according with the organization's SMS Resend
-// Interval and the maximum number of SMS resend attempts.
-
-func (s SendReceiverWalletInviteService) shouldSendInvitationSMS(ctx context.Context, organization *data.Organization, rwa *data.ReceiverWalletAsset) bool {
- truncatedPhoneNumber := utils.TruncateString(rwa.ReceiverWallet.Receiver.PhoneNumber, 3)
+// shouldSendInvitation returns true if we should send the invitation to the receiver. It will be used to either
+// send the invitation for the first time, or to resend it automatically according to the organization's Resend
+// Interval and the maximum number of resend attempts.
+func (s SendReceiverWalletInviteService) shouldSendInvitation(ctx context.Context, organization *data.Organization, rwa *data.ReceiverWalletAsset) bool {
+ receiver := rwa.ReceiverWallet.Receiver
- // We've never sent a Invitation SMS
+ // We've never sent an Invitation message
if rwa.ReceiverWallet.InvitationSentAt == nil {
return true
}
- // If organization's SMS Resend Interval is nil and we've sent the invitation message to the receiver, we won't resend it.
- if organization.SMSResendInterval == nil && rwa.ReceiverWallet.InvitationSentAt != nil {
+ // If organization's Receiver Invitation Resend Interval is nil and we've sent the invitation message to the receiver, we won't resend it.
+ if organization.ReceiverInvitationResendIntervalDays == nil && rwa.ReceiverWallet.InvitationSentAt != nil {
log.Ctx(ctx).Debugf(
- "the invitation message was not automatically resent to the receiver %s with phone number %s because the organization's SMS Resend Interval is nil",
- rwa.ReceiverWallet.Receiver.ID, truncatedPhoneNumber)
+ "the invitation message was not automatically resent to the receiver %s because the organization's Receiver Invitation Resend Interval is nil",
+ receiver.ID)
return false
}
- // The organizations has a interval to automatic resend the Invitation SMS.
- if organization.SMSResendInterval != nil {
- // Check if the receiver wallet reached the maximum number of SMS resend attempts.
- if rwa.ReceiverWallet.ReceiverWalletStats.TotalInvitationSMSResentAttempts >= s.maxInvitationSMSResendAttempts {
+ // The organizations defined an interval to automatically resend the receiver invitation message.
+ if organization.ReceiverInvitationResendIntervalDays != nil {
+ // Check if the receiver wallet reached the maximum number of resend attempts.
+ if rwa.ReceiverWallet.ReceiverWalletStats.TotalInvitationResentAttempts >= s.maxInvitationResendAttempts {
log.Ctx(ctx).Debugf(
- "the invitation message was not resent to the receiver because the maximum number of SMS resend attempts has been reached: Phone Number: %s - Receiver ID %s - Wallet ID %s - Total Invitation SMS resent %d - Maximum attempts %d",
- truncatedPhoneNumber,
- rwa.ReceiverWallet.Receiver.ID,
+ "the invitation message was not resent to the receiver because the maximum number of message resend attempts has been reached: Receiver ID %s - Wallet ID %s - Total Invitation resent %d - Maximum attempts %d",
+ receiver.ID,
rwa.WalletID,
- rwa.ReceiverWallet.ReceiverWalletStats.TotalInvitationSMSResentAttempts,
- s.maxInvitationSMSResendAttempts,
+ rwa.ReceiverWallet.ReceiverWalletStats.TotalInvitationResentAttempts,
+ s.maxInvitationResendAttempts,
)
return false
}
// Check if it's in the period to resend it.
resendPeriod := time.Now().
- AddDate(0, 0, -int(*organization.SMSResendInterval*(rwa.ReceiverWallet.ReceiverWalletStats.TotalInvitationSMSResentAttempts+1)))
+ AddDate(0, 0, -int(*organization.ReceiverInvitationResendIntervalDays*(rwa.ReceiverWallet.ReceiverWalletStats.TotalInvitationResentAttempts+1)))
if !rwa.ReceiverWallet.InvitationSentAt.Before(resendPeriod) {
log.Ctx(ctx).Debugf(
- "the invitation message was not automatically resent to the receiver because the receiver is not in the resend period: Phone Number: %s - Receiver ID %s - Wallet ID %s - Last Invitation Sent At %s - SMS Resend Interval %d day(s)",
- truncatedPhoneNumber,
- rwa.ReceiverWallet.Receiver.ID,
+ "the invitation message was not automatically resent to the receiver because the receiver is not in the resend period: Receiver ID %s - Wallet ID %s - Last Invitation Sent At %s - Receiver Invitation Resend Interval %d day(s)",
+ receiver.ID,
rwa.WalletID,
rwa.ReceiverWallet.InvitationSentAt.Format(time.RFC1123),
- *organization.SMSResendInterval,
+ *organization.ReceiverInvitationResendIntervalDays,
)
return false
}
@@ -284,13 +284,13 @@ func (s SendReceiverWalletInviteService) shouldSendInvitationSMS(ctx context.Con
return true
}
-func NewSendReceiverWalletInviteService(models *data.Models, messengerClient message.MessengerClient, sep10SigningPrivateKey string, maxInvitationSMSResendAttempts int64, crashTrackerClient crashtracker.CrashTrackerClient) (*SendReceiverWalletInviteService, error) {
+func NewSendReceiverWalletInviteService(models *data.Models, messageDispatcher message.MessageDispatcherInterface, sep10SigningPrivateKey string, maxInvitationResendAttempts int64, crashTrackerClient crashtracker.CrashTrackerClient) (*SendReceiverWalletInviteService, error) {
s := &SendReceiverWalletInviteService{
- messengerClient: messengerClient,
- Models: models,
- maxInvitationSMSResendAttempts: maxInvitationSMSResendAttempts,
- sep10SigningPrivateKey: sep10SigningPrivateKey,
- crashTrackerClient: crashTrackerClient,
+ messageDispatcher: messageDispatcher,
+ Models: models,
+ maxInvitationResendAttempts: maxInvitationResendAttempts,
+ sep10SigningPrivateKey: sep10SigningPrivateKey,
+ crashTrackerClient: crashTrackerClient,
}
if err := s.validate(); err != nil {
diff --git a/internal/services/send_receiver_wallets_invite_service_test.go b/internal/services/send_receiver_wallets_invite_service_test.go
index 9abbc4eb0..b71b9f5ba 100644
--- a/internal/services/send_receiver_wallets_invite_service_test.go
+++ b/internal/services/send_receiver_wallets_invite_service_test.go
@@ -10,6 +10,7 @@ import (
"github.com/google/uuid"
"github.com/stellar/go/support/log"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stellar/stellar-disbursement-platform-backend/db"
@@ -50,7 +51,7 @@ func Test_GetSignedRegistrationLink_SchemelessDeepLink(t *testing.T) {
require.Equal(t, wantRegistrationLink, registrationLink)
}
-func Test_SendReceiverWalletInviteService(t *testing.T) {
+func Test_SendReceiverWalletInviteService_SendInvite(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()
@@ -63,19 +64,13 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
ctx := tenant.SaveTenantInContext(context.Background(), tenantInfo)
stellarSecretKey := "SBUSPEKAZKLZSWHRSJ2HWDZUK6I3IVDUWA7JJZSGBLZ2WZIUJI7FPNB5"
- messengerClientMock := &message.MessengerClientMock{}
- messengerClientMock.
- On("MessengerType").
- Return(message.MessengerTypeTwilioSMS).
- Maybe()
+ messageDispatcherMock := message.NewMockMessageDispatcher(t)
mockCrashTrackerClient := &crashtracker.MockCrashTrackerClient{}
models, err := data.NewModels(dbConnectionPool)
require.NoError(t, err)
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "ATL", "Atlantis")
-
wallet1 := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet1", "https://wallet1.com", "www.wallet1.com", "wallet1://sdp")
wallet2 := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet2", "https://wallet2.com", "www.wallet2.com", "wallet2://sdp")
@@ -84,28 +79,32 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
receiver1 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{})
receiver2 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{})
+ receiverEmailOnly := data.InsertReceiverFixture(t, ctx, dbConnectionPool, &data.ReceiverInsert{
+ Email: utils.StringPtr("emailJWP5O@randomemail.com"),
+ })
+ receiverPhoneOnly := data.InsertReceiverFixture(t, ctx, dbConnectionPool, &data.ReceiverInsert{
+ PhoneNumber: utils.StringPtr("1234567890"),
+ })
disbursement1 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
- Wallet: wallet1,
- Status: data.ReadyDisbursementStatus,
- Asset: asset1,
+ Wallet: wallet1,
+ Status: data.ReadyDisbursementStatus,
+ Asset: asset1,
})
disbursement2 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
- Wallet: wallet2,
- Status: data.ReadyDisbursementStatus,
- Asset: asset2,
+ Wallet: wallet2,
+ Status: data.ReadyDisbursementStatus,
+ Asset: asset2,
})
t.Run("returns error when service has wrong setup", func(t *testing.T) {
_, err := NewSendReceiverWalletInviteService(models, nil, stellarSecretKey, 3, mockCrashTrackerClient)
- assert.EqualError(t, err, "invalid service setup: messenger client can't be nil")
+ assert.EqualError(t, err, "invalid service setup: messenger dispatcher can't be nil")
})
t.Run("inserts the failed sent message", func(t *testing.T) {
- s, err := NewSendReceiverWalletInviteService(models, messengerClientMock, stellarSecretKey, 3, mockCrashTrackerClient)
+ s, err := NewSendReceiverWalletInviteService(models, messageDispatcherMock, stellarSecretKey, 3, mockCrashTrackerClient)
require.NoError(t, err)
data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool)
@@ -143,6 +142,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
deepLink1, err := walletDeepLink1.GetSignedRegistrationLink(stellarSecretKey)
require.NoError(t, err)
contentWallet1 := fmt.Sprintf("You have a payment waiting for you from the MyCustomAid. Click %s to register.", deepLink1)
+ titleWallet1 := "You have a payment waiting for you from " + walletDeepLink1.OrganizationName
walletDeepLink2 := WalletDeepLink{
DeepLink: wallet2.DeepLinkSchema,
@@ -154,20 +154,25 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
deepLink2, err := walletDeepLink2.GetSignedRegistrationLink(stellarSecretKey)
require.NoError(t, err)
contentWallet2 := fmt.Sprintf("You have a payment waiting for you from the MyCustomAid. Click %s to register.", deepLink2)
+ titleWallet2 := "You have a payment waiting for you from " + walletDeepLink2.OrganizationName
mockErr := errors.New("unexpected error")
- messengerClientMock.
- On("SendMessage", message.Message{
+ messageDispatcherMock.
+ On("SendMessage", mock.Anything, message.Message{
ToPhoneNumber: receiver1.PhoneNumber,
- Message: contentWallet1,
- }).
- Return(errors.New("unexpected error")).
+ ToEmail: receiver1.Email,
+ Body: contentWallet1,
+ Title: titleWallet1,
+ }, []message.MessageChannel{message.MessageChannelSMS, message.MessageChannelEmail}).
+ Return(message.MessengerTypeTwilioSMS, errors.New("unexpected error")).
Once().
- On("SendMessage", message.Message{
+ On("SendMessage", mock.Anything, message.Message{
ToPhoneNumber: receiver2.PhoneNumber,
- Message: contentWallet2,
- }).
- Return(nil).
+ ToEmail: receiver2.Email,
+ Body: contentWallet2,
+ Title: titleWallet2,
+ }, []message.MessageChannel{message.MessageChannelSMS, message.MessageChannelEmail}).
+ Return(message.MessengerTypeTwilioSMS, nil).
Once()
mockMsg := fmt.Sprintf(
@@ -176,7 +181,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
)
mockCrashTrackerClient.On("LogAndReportErrors", ctx, mockErr, mockMsg).Once()
- reqs := []schemas.EventReceiverWalletSMSInvitationData{
+ reqs := []schemas.EventReceiverWalletInvitationData{
{
ReceiverWalletID: rec1RW.ID,
},
@@ -218,7 +223,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
assert.Equal(t, wallet1.ID, msg.WalletID)
assert.Equal(t, rec1RW.ID, *msg.ReceiverWalletID)
assert.Equal(t, data.FailureMessageStatus, msg.Status)
- assert.Empty(t, msg.TitleEncrypted)
+ assert.Equal(t, titleWallet1, msg.TitleEncrypted)
assert.Equal(t, contentWallet1, msg.TextEncrypted)
assert.Len(t, msg.StatusHistory, 2)
assert.Equal(t, data.PendingMessageStatus, msg.StatusHistory[0].Status)
@@ -234,7 +239,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
assert.Equal(t, wallet2.ID, msg.WalletID)
assert.Equal(t, rec2RW.ID, *msg.ReceiverWalletID)
assert.Equal(t, data.SuccessMessageStatus, msg.Status)
- assert.Empty(t, msg.TitleEncrypted)
+ assert.Equal(t, titleWallet2, msg.TitleEncrypted)
assert.Equal(t, contentWallet2, msg.TextEncrypted)
assert.Len(t, msg.StatusHistory, 2)
assert.Equal(t, data.PendingMessageStatus, msg.StatusHistory[0].Status)
@@ -245,17 +250,17 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
})
t.Run("send invite successfully", func(t *testing.T) {
- s, err := NewSendReceiverWalletInviteService(models, messengerClientMock, stellarSecretKey, 3, mockCrashTrackerClient)
+ s, err := NewSendReceiverWalletInviteService(models, messageDispatcherMock, stellarSecretKey, 3, mockCrashTrackerClient)
require.NoError(t, err)
data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool)
data.DeleteAllMessagesFixtures(t, ctx, dbConnectionPool)
data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool)
- rec1RW := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet1.ID, data.ReadyReceiversWalletStatus)
- data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet2.ID, data.RegisteredReceiversWalletStatus)
+ rec1RW := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiverPhoneOnly.ID, wallet1.ID, data.ReadyReceiversWalletStatus)
+ data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiverPhoneOnly.ID, wallet2.ID, data.RegisteredReceiversWalletStatus)
- rec2RW := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet2.ID, data.ReadyReceiversWalletStatus)
+ rec2RW := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiverEmailOnly.ID, wallet2.ID, data.ReadyReceiversWalletStatus)
_ = data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
Status: data.ReadyPaymentStatus,
@@ -283,6 +288,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
deepLink1, err := walletDeepLink1.GetSignedRegistrationLink(stellarSecretKey)
require.NoError(t, err)
contentWallet1 := fmt.Sprintf("You have a payment waiting for you from the MyCustomAid. Click %s to register.", deepLink1)
+ // titleWallet1 := "You have a payment waiting for you from " + walletDeepLink1.OrganizationName
walletDeepLink2 := WalletDeepLink{
DeepLink: wallet2.DeepLinkSchema,
@@ -294,22 +300,24 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
deepLink2, err := walletDeepLink2.GetSignedRegistrationLink(stellarSecretKey)
require.NoError(t, err)
contentWallet2 := fmt.Sprintf("You have a payment waiting for you from the MyCustomAid. Click %s to register.", deepLink2)
-
- messengerClientMock.
- On("SendMessage", message.Message{
- ToPhoneNumber: receiver1.PhoneNumber,
- Message: contentWallet1,
- }).
- Return(nil).
+ titleWallet2 := "You have a payment waiting for you from " + walletDeepLink2.OrganizationName
+
+ messageDispatcherMock.
+ On("SendMessage", mock.Anything, message.Message{
+ ToPhoneNumber: receiverPhoneOnly.PhoneNumber,
+ Body: contentWallet1,
+ }, []message.MessageChannel{message.MessageChannelSMS, message.MessageChannelEmail}).
+ Return(message.MessengerTypeTwilioSMS, nil).
Once().
- On("SendMessage", message.Message{
- ToPhoneNumber: receiver2.PhoneNumber,
- Message: contentWallet2,
- }).
- Return(nil).
+ On("SendMessage", mock.Anything, message.Message{
+ ToEmail: receiverEmailOnly.Email,
+ Body: contentWallet2,
+ Title: titleWallet2,
+ }, []message.MessageChannel{message.MessageChannelSMS, message.MessageChannelEmail}).
+ Return(message.MessengerTypeAWSEmail, nil).
Once()
- reqs := []schemas.EventReceiverWalletSMSInvitationData{
+ reqs := []schemas.EventReceiverWalletInvitationData{
{
ReceiverWalletID: rec1RW.ID,
},
@@ -321,13 +329,13 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
err = s.SendInvite(ctx, reqs...)
require.NoError(t, err)
- receivers, err := models.ReceiverWallet.GetByReceiverIDsAndWalletID(ctx, dbConnectionPool, []string{receiver1.ID}, wallet1.ID)
+ receivers, err := models.ReceiverWallet.GetByReceiverIDsAndWalletID(ctx, dbConnectionPool, []string{receiverPhoneOnly.ID}, wallet1.ID)
require.NoError(t, err)
require.Len(t, receivers, 1)
assert.Equal(t, rec1RW.ID, receivers[0].ID)
assert.NotNil(t, receivers[0].InvitationSentAt)
- receivers, err = models.ReceiverWallet.GetByReceiverIDsAndWalletID(ctx, dbConnectionPool, []string{receiver2.ID}, wallet2.ID)
+ receivers, err = models.ReceiverWallet.GetByReceiverIDsAndWalletID(ctx, dbConnectionPool, []string{receiverEmailOnly.ID}, wallet2.ID)
require.NoError(t, err)
require.Len(t, receivers, 1)
assert.Equal(t, rec2RW.ID, receivers[0].ID)
@@ -343,11 +351,11 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
receiver_id = $1 AND wallet_id = $2 AND receiver_wallet_id = $3
`
var msg data.Message
- err = dbConnectionPool.GetContext(ctx, &msg, q, receiver1.ID, wallet1.ID, rec1RW.ID)
+ err = dbConnectionPool.GetContext(ctx, &msg, q, receiverPhoneOnly.ID, wallet1.ID, rec1RW.ID)
require.NoError(t, err)
assert.Equal(t, message.MessengerTypeTwilioSMS, msg.Type)
- assert.Equal(t, receiver1.ID, msg.ReceiverID)
+ assert.Equal(t, receiverPhoneOnly.ID, msg.ReceiverID)
assert.Equal(t, wallet1.ID, msg.WalletID)
assert.Equal(t, rec1RW.ID, *msg.ReceiverWalletID)
assert.Equal(t, data.SuccessMessageStatus, msg.Status)
@@ -359,15 +367,15 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
assert.Nil(t, msg.AssetID)
msg = data.Message{}
- err = dbConnectionPool.GetContext(ctx, &msg, q, receiver2.ID, wallet2.ID, rec2RW.ID)
+ err = dbConnectionPool.GetContext(ctx, &msg, q, receiverEmailOnly.ID, wallet2.ID, rec2RW.ID)
require.NoError(t, err)
- assert.Equal(t, message.MessengerTypeTwilioSMS, msg.Type)
- assert.Equal(t, receiver2.ID, msg.ReceiverID)
+ assert.Equal(t, message.MessengerTypeAWSEmail, msg.Type)
+ assert.Equal(t, receiverEmailOnly.ID, msg.ReceiverID)
assert.Equal(t, wallet2.ID, msg.WalletID)
assert.Equal(t, rec2RW.ID, *msg.ReceiverWalletID)
assert.Equal(t, data.SuccessMessageStatus, msg.Status)
- assert.Empty(t, msg.TitleEncrypted)
+ assert.Equal(t, titleWallet2, msg.TitleEncrypted)
assert.Equal(t, contentWallet2, msg.TextEncrypted)
assert.Len(t, msg.StatusHistory, 2)
assert.Equal(t, data.PendingMessageStatus, msg.StatusHistory[0].Status)
@@ -376,7 +384,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
})
t.Run("send invite successfully with custom invite message", func(t *testing.T) {
- s, err := NewSendReceiverWalletInviteService(models, messengerClientMock, stellarSecretKey, 3, mockCrashTrackerClient)
+ s, err := NewSendReceiverWalletInviteService(models, messageDispatcherMock, stellarSecretKey, 3, mockCrashTrackerClient)
require.NoError(t, err)
data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool)
@@ -389,7 +397,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
rec2RW := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet2.ID, data.ReadyReceiversWalletStatus)
customInvitationMessage := "My custom receiver wallet registration invite. MyOrg 👋"
- err = models.Organizations.Update(ctx, &data.OrganizationUpdate{SMSRegistrationMessageTemplate: &customInvitationMessage})
+ err = models.Organizations.Update(ctx, &data.OrganizationUpdate{ReceiverRegistrationMessageTemplate: &customInvitationMessage})
require.NoError(t, err)
_ = data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{
@@ -418,6 +426,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
deepLink1, err := walletDeepLink1.GetSignedRegistrationLink(stellarSecretKey)
require.NoError(t, err)
contentWallet1 := fmt.Sprintf("%s %s", customInvitationMessage, deepLink1)
+ titleWallet1 := "You have a payment waiting for you from " + walletDeepLink1.OrganizationName
walletDeepLink2 := WalletDeepLink{
DeepLink: wallet2.DeepLinkSchema,
@@ -429,22 +438,27 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
deepLink2, err := walletDeepLink2.GetSignedRegistrationLink(stellarSecretKey)
require.NoError(t, err)
contentWallet2 := fmt.Sprintf("%s %s", customInvitationMessage, deepLink2)
+ titleWallet2 := "You have a payment waiting for you from " + walletDeepLink2.OrganizationName
- messengerClientMock.
- On("SendMessage", message.Message{
+ messageDispatcherMock.
+ On("SendMessage", mock.Anything, message.Message{
ToPhoneNumber: receiver1.PhoneNumber,
- Message: contentWallet1,
- }).
- Return(nil).
+ ToEmail: receiver1.Email,
+ Body: contentWallet1,
+ Title: titleWallet1,
+ }, []message.MessageChannel{message.MessageChannelSMS, message.MessageChannelEmail}).
+ Return(message.MessengerTypeTwilioSMS, nil).
Once().
- On("SendMessage", message.Message{
+ On("SendMessage", mock.Anything, message.Message{
ToPhoneNumber: receiver2.PhoneNumber,
- Message: contentWallet2,
- }).
- Return(nil).
+ ToEmail: receiver2.Email,
+ Body: contentWallet2,
+ Title: titleWallet2,
+ }, []message.MessageChannel{message.MessageChannelSMS, message.MessageChannelEmail}).
+ Return(message.MessengerTypeTwilioSMS, nil).
Once()
- reqs := []schemas.EventReceiverWalletSMSInvitationData{
+ reqs := []schemas.EventReceiverWalletInvitationData{
{
ReceiverWalletID: rec1RW.ID,
},
@@ -486,7 +500,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
assert.Equal(t, wallet1.ID, msg.WalletID)
assert.Equal(t, rec1RW.ID, *msg.ReceiverWalletID)
assert.Equal(t, data.SuccessMessageStatus, msg.Status)
- assert.Empty(t, msg.TitleEncrypted)
+ assert.Equal(t, titleWallet1, msg.TitleEncrypted)
assert.Equal(t, contentWallet1, msg.TextEncrypted)
assert.Len(t, msg.StatusHistory, 2)
assert.Equal(t, data.PendingMessageStatus, msg.StatusHistory[0].Status)
@@ -502,7 +516,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
assert.Equal(t, wallet2.ID, msg.WalletID)
assert.Equal(t, rec2RW.ID, *msg.ReceiverWalletID)
assert.Equal(t, data.SuccessMessageStatus, msg.Status)
- assert.Empty(t, msg.TitleEncrypted)
+ assert.Equal(t, titleWallet1, msg.TitleEncrypted)
assert.Equal(t, contentWallet2, msg.TextEncrypted)
assert.Len(t, msg.StatusHistory, 2)
assert.Equal(t, data.PendingMessageStatus, msg.StatusHistory[0].Status)
@@ -511,7 +525,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
})
t.Run("doesn't resend the invitation SMS when organization's SMS Resend Interval is nil and the invitation was already sent", func(t *testing.T) {
- s, err := NewSendReceiverWalletInviteService(models, messengerClientMock, stellarSecretKey, 3, mockCrashTrackerClient)
+ s, err := NewSendReceiverWalletInviteService(models, messageDispatcherMock, stellarSecretKey, 3, mockCrashTrackerClient)
require.NoError(t, err)
data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool)
@@ -535,10 +549,10 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
err = dbConnectionPool.GetContext(ctx, &invitationSentAt, q, rec1RW.ID)
require.NoError(t, err)
- err = models.Organizations.Update(ctx, &data.OrganizationUpdate{SMSResendInterval: new(int64)})
+ err = models.Organizations.Update(ctx, &data.OrganizationUpdate{ReceiverInvitationResendIntervalDays: new(int64)})
require.NoError(t, err)
- reqs := []schemas.EventReceiverWalletSMSInvitationData{
+ reqs := []schemas.EventReceiverWalletInvitationData{
{
ReceiverWalletID: rec1RW.ID,
},
@@ -556,7 +570,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
})
t.Run("doesn't resend the invitation SMS when receiver reached the maximum number of resend attempts", func(t *testing.T) {
- s, err := NewSendReceiverWalletInviteService(models, messengerClientMock, stellarSecretKey, 3, mockCrashTrackerClient)
+ s, err := NewSendReceiverWalletInviteService(models, messageDispatcherMock, stellarSecretKey, 3, mockCrashTrackerClient)
require.NoError(t, err)
data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool)
@@ -582,7 +596,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
// Set the SMS Resend Interval
var smsResendInterval int64 = 2
- err = models.Organizations.Update(ctx, &data.OrganizationUpdate{SMSResendInterval: &smsResendInterval})
+ err = models.Organizations.Update(ctx, &data.OrganizationUpdate{ReceiverInvitationResendIntervalDays: &smsResendInterval})
require.NoError(t, err)
_ = data.CreateMessageFixture(t, ctx, dbConnectionPool, &data.Message{
@@ -618,7 +632,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
UpdatedAt: time.Now().AddDate(0, 0, int(smsResendInterval*3)),
})
- reqs := []schemas.EventReceiverWalletSMSInvitationData{
+ reqs := []schemas.EventReceiverWalletInvitationData{
{
ReceiverWalletID: rec1RW.ID,
},
@@ -636,7 +650,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
})
t.Run("doesn't resend invitation SMS when receiver is not in the resend period", func(t *testing.T) {
- s, err := NewSendReceiverWalletInviteService(models, messengerClientMock, stellarSecretKey, 3, mockCrashTrackerClient)
+ s, err := NewSendReceiverWalletInviteService(models, messageDispatcherMock, stellarSecretKey, 3, mockCrashTrackerClient)
require.NoError(t, err)
data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool)
@@ -662,10 +676,10 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
// Set the SMS Resend Interval
var smsResendInterval int64 = 2
- err = models.Organizations.Update(ctx, &data.OrganizationUpdate{SMSResendInterval: &smsResendInterval})
+ err = models.Organizations.Update(ctx, &data.OrganizationUpdate{ReceiverInvitationResendIntervalDays: &smsResendInterval})
require.NoError(t, err)
- reqs := []schemas.EventReceiverWalletSMSInvitationData{
+ reqs := []schemas.EventReceiverWalletInvitationData{
{
ReceiverWalletID: rec1RW.ID,
},
@@ -683,7 +697,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
})
t.Run("successfully resend the invitation SMS", func(t *testing.T) {
- s, err := NewSendReceiverWalletInviteService(models, messengerClientMock, stellarSecretKey, 3, mockCrashTrackerClient)
+ s, err := NewSendReceiverWalletInviteService(models, messageDispatcherMock, stellarSecretKey, 3, mockCrashTrackerClient)
require.NoError(t, err)
data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool)
@@ -709,7 +723,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
// Set the SMS Resend Interval
var smsResendInterval int64 = 2
- err = models.Organizations.Update(ctx, &data.OrganizationUpdate{SMSResendInterval: &smsResendInterval, SMSRegistrationMessageTemplate: new(string)})
+ err = models.Organizations.Update(ctx, &data.OrganizationUpdate{ReceiverInvitationResendIntervalDays: &smsResendInterval, ReceiverRegistrationMessageTemplate: new(string)})
require.NoError(t, err)
walletDeepLink1 := WalletDeepLink{
@@ -722,16 +736,19 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
deepLink1, err := walletDeepLink1.GetSignedRegistrationLink(stellarSecretKey)
require.NoError(t, err)
contentWallet1 := fmt.Sprintf("You have a payment waiting for you from the MyCustomAid. Click %s to register.", deepLink1)
+ titleWallet1 := "You have a payment waiting for you from " + walletDeepLink1.OrganizationName
- messengerClientMock.
- On("SendMessage", message.Message{
+ messageDispatcherMock.
+ On("SendMessage", mock.Anything, message.Message{
ToPhoneNumber: receiver1.PhoneNumber,
- Message: contentWallet1,
- }).
- Return(nil).
+ ToEmail: receiver1.Email,
+ Body: contentWallet1,
+ Title: titleWallet1,
+ }, []message.MessageChannel{message.MessageChannelSMS, message.MessageChannelEmail}).
+ Return(message.MessengerTypeTwilioSMS, nil).
Once()
- reqs := []schemas.EventReceiverWalletSMSInvitationData{
+ reqs := []schemas.EventReceiverWalletInvitationData{
{
ReceiverWalletID: rec1RW.ID,
},
@@ -766,7 +783,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
assert.Equal(t, wallet1.ID, msg.WalletID)
assert.Equal(t, rec1RW.ID, *msg.ReceiverWalletID)
assert.Equal(t, data.SuccessMessageStatus, msg.Status)
- assert.Empty(t, msg.TitleEncrypted)
+ assert.Equal(t, titleWallet1, msg.TitleEncrypted)
assert.Equal(t, contentWallet1, msg.TextEncrypted)
assert.Len(t, msg.StatusHistory, 2)
assert.Equal(t, data.PendingMessageStatus, msg.StatusHistory[0].Status)
@@ -776,22 +793,20 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
t.Run("send disbursement invite successfully", func(t *testing.T) {
disbursement3 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
- Wallet: wallet1,
- Status: data.ReadyDisbursementStatus,
- Asset: asset1,
- SMSRegistrationMessageTemplate: "SMS Registration Message template test disbursement 3:",
+ Wallet: wallet1,
+ Status: data.ReadyDisbursementStatus,
+ Asset: asset1,
+ ReceiverRegistrationMessageTemplate: "SMS Registration Message template test disbursement 3:",
})
disbursement4 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
- Wallet: wallet2,
- Status: data.ReadyDisbursementStatus,
- Asset: asset2,
- SMSRegistrationMessageTemplate: "SMS Registration Message template test disbursement 4:",
+ Wallet: wallet2,
+ Status: data.ReadyDisbursementStatus,
+ Asset: asset2,
+ ReceiverRegistrationMessageTemplate: "SMS Registration Message template test disbursement 4:",
})
- s, err := NewSendReceiverWalletInviteService(models, messengerClientMock, stellarSecretKey, 3, mockCrashTrackerClient)
+ s, err := NewSendReceiverWalletInviteService(models, messageDispatcherMock, stellarSecretKey, 3, mockCrashTrackerClient)
require.NoError(t, err)
data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool)
@@ -828,7 +843,8 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
}
deepLink1, err := walletDeepLink1.GetSignedRegistrationLink(stellarSecretKey)
require.NoError(t, err)
- contentDisbursement3 := fmt.Sprintf("%s %s", disbursement3.SMSRegistrationMessageTemplate, deepLink1)
+ contentDisbursement3 := fmt.Sprintf("%s %s", disbursement3.ReceiverRegistrationMessageTemplate, deepLink1)
+ titleDisbursement3 := "You have a payment waiting for you from " + walletDeepLink1.OrganizationName
walletDeepLink2 := WalletDeepLink{
DeepLink: wallet2.DeepLinkSchema,
@@ -839,23 +855,28 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
}
deepLink2, err := walletDeepLink2.GetSignedRegistrationLink(stellarSecretKey)
require.NoError(t, err)
- contentDisbursement4 := fmt.Sprintf("%s %s", disbursement4.SMSRegistrationMessageTemplate, deepLink2)
+ contentDisbursement4 := fmt.Sprintf("%s %s", disbursement4.ReceiverRegistrationMessageTemplate, deepLink2)
+ titleDisbursement4 := "You have a payment waiting for you from " + walletDeepLink2.OrganizationName
- messengerClientMock.
- On("SendMessage", message.Message{
+ messageDispatcherMock.
+ On("SendMessage", mock.Anything, message.Message{
ToPhoneNumber: receiver1.PhoneNumber,
- Message: contentDisbursement3,
- }).
- Return(nil).
+ ToEmail: receiver1.Email,
+ Body: contentDisbursement3,
+ Title: titleDisbursement3,
+ }, []message.MessageChannel{message.MessageChannelSMS, message.MessageChannelEmail}).
+ Return(message.MessengerTypeTwilioSMS, nil).
Once().
- On("SendMessage", message.Message{
+ On("SendMessage", mock.Anything, message.Message{
ToPhoneNumber: receiver2.PhoneNumber,
- Message: contentDisbursement4,
- }).
- Return(nil).
+ ToEmail: receiver2.Email,
+ Body: contentDisbursement4,
+ Title: titleDisbursement4,
+ }, []message.MessageChannel{message.MessageChannelSMS, message.MessageChannelEmail}).
+ Return(message.MessengerTypeTwilioSMS, nil).
Once()
- reqs := []schemas.EventReceiverWalletSMSInvitationData{
+ reqs := []schemas.EventReceiverWalletInvitationData{
{
ReceiverWalletID: rec1RW.ID,
},
@@ -897,7 +918,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
assert.Equal(t, wallet1.ID, msg.WalletID)
assert.Equal(t, rec1RW.ID, *msg.ReceiverWalletID)
assert.Equal(t, data.SuccessMessageStatus, msg.Status)
- assert.Empty(t, msg.TitleEncrypted)
+ assert.Equal(t, titleDisbursement3, msg.TitleEncrypted)
assert.Equal(t, contentDisbursement3, msg.TextEncrypted)
assert.Len(t, msg.StatusHistory, 2)
assert.Equal(t, data.PendingMessageStatus, msg.StatusHistory[0].Status)
@@ -913,7 +934,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
assert.Equal(t, wallet2.ID, msg.WalletID)
assert.Equal(t, rec2RW.ID, *msg.ReceiverWalletID)
assert.Equal(t, data.SuccessMessageStatus, msg.Status)
- assert.Empty(t, msg.TitleEncrypted)
+ assert.Equal(t, titleDisbursement4, msg.TitleEncrypted)
assert.Equal(t, contentDisbursement4, msg.TextEncrypted)
assert.Len(t, msg.StatusHistory, 2)
assert.Equal(t, data.PendingMessageStatus, msg.StatusHistory[0].Status)
@@ -923,14 +944,13 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
t.Run("successfully resend the disbursement invitation SMS", func(t *testing.T) {
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Country: country,
- Wallet: wallet1,
- Status: data.ReadyDisbursementStatus,
- Asset: asset1,
- SMSRegistrationMessageTemplate: "SMS Registration Message template test disbursement:",
+ Wallet: wallet1,
+ Status: data.ReadyDisbursementStatus,
+ Asset: asset1,
+ ReceiverRegistrationMessageTemplate: "SMS Registration Message template test disbursement:",
})
- s, err := NewSendReceiverWalletInviteService(models, messengerClientMock, stellarSecretKey, 3, mockCrashTrackerClient)
+ s, err := NewSendReceiverWalletInviteService(models, messageDispatcherMock, stellarSecretKey, 3, mockCrashTrackerClient)
require.NoError(t, err)
data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool)
@@ -956,7 +976,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
// Set the SMS Resend Interval
var smsResendInterval int64 = 2
- err = models.Organizations.Update(ctx, &data.OrganizationUpdate{SMSResendInterval: &smsResendInterval, SMSRegistrationMessageTemplate: new(string)})
+ err = models.Organizations.Update(ctx, &data.OrganizationUpdate{ReceiverInvitationResendIntervalDays: &smsResendInterval, ReceiverRegistrationMessageTemplate: new(string)})
require.NoError(t, err)
walletDeepLink1 := WalletDeepLink{
@@ -968,17 +988,20 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
}
deepLink1, err := walletDeepLink1.GetSignedRegistrationLink(stellarSecretKey)
require.NoError(t, err)
- contentDisbursement := fmt.Sprintf("%s %s", disbursement.SMSRegistrationMessageTemplate, deepLink1)
+ contentDisbursement := fmt.Sprintf("%s %s", disbursement.ReceiverRegistrationMessageTemplate, deepLink1)
+ titleDisbursement := "You have a payment waiting for you from " + walletDeepLink1.OrganizationName
- messengerClientMock.
- On("SendMessage", message.Message{
+ messageDispatcherMock.
+ On("SendMessage", mock.Anything, message.Message{
ToPhoneNumber: receiver1.PhoneNumber,
- Message: contentDisbursement,
- }).
- Return(nil).
+ ToEmail: receiver1.Email,
+ Body: contentDisbursement,
+ Title: titleDisbursement,
+ }, []message.MessageChannel{message.MessageChannelSMS, message.MessageChannelEmail}).
+ Return(message.MessengerTypeTwilioSMS, nil).
Once()
- reqs := []schemas.EventReceiverWalletSMSInvitationData{
+ reqs := []schemas.EventReceiverWalletInvitationData{
{
ReceiverWalletID: rec1RW.ID,
},
@@ -1013,7 +1036,7 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
assert.Equal(t, wallet1.ID, msg.WalletID)
assert.Equal(t, rec1RW.ID, *msg.ReceiverWalletID)
assert.Equal(t, data.SuccessMessageStatus, msg.Status)
- assert.Empty(t, msg.TitleEncrypted)
+ assert.Equal(t, titleDisbursement, msg.TitleEncrypted)
assert.Equal(t, contentDisbursement, msg.TextEncrypted)
assert.Len(t, msg.StatusHistory, 2)
assert.Equal(t, data.PendingMessageStatus, msg.StatusHistory[0].Status)
@@ -1021,41 +1044,44 @@ func Test_SendReceiverWalletInviteService(t *testing.T) {
assert.Nil(t, msg.AssetID)
})
- messengerClientMock.AssertExpectations(t)
+ messageDispatcherMock.AssertExpectations(t)
}
-func Test_SendReceiverWalletInviteService_shouldSendInvitationSMS(t *testing.T) {
- var maxInvitationSMSResendAttempts int64 = 3
- s := SendReceiverWalletInviteService{maxInvitationSMSResendAttempts: maxInvitationSMSResendAttempts}
+func Test_SendReceiverWalletInviteService_shouldSendInvitation(t *testing.T) {
+ var maxInvitationResendAttempts int64 = 3
+ s := SendReceiverWalletInviteService{maxInvitationResendAttempts: maxInvitationResendAttempts}
ctx := context.Background()
t.Run("returns true when user never received the invitation SMS", func(t *testing.T) {
- org := data.Organization{SMSResendInterval: nil}
+ org := data.Organization{ReceiverInvitationResendIntervalDays: nil}
rwa := data.ReceiverWalletAsset{
ReceiverWallet: data.ReceiverWallet{
InvitationSentAt: nil,
+ Receiver: data.Receiver{
+ PhoneNumber: "+380443973607",
+ },
},
}
- got := s.shouldSendInvitationSMS(ctx, &org, &rwa)
+ got := s.shouldSendInvitation(ctx, &org, &rwa)
assert.True(t, got)
})
t.Run("returns false when user received the invitation SMS and organization's SMS Resend Interval is not set", func(t *testing.T) {
invitationSentAt := time.Now()
- org := data.Organization{SMSResendInterval: nil}
+ org := data.Organization{ReceiverInvitationResendIntervalDays: nil}
rwa := data.ReceiverWalletAsset{
ReceiverWallet: data.ReceiverWallet{
InvitationSentAt: &invitationSentAt,
},
}
- got := s.shouldSendInvitationSMS(ctx, &org, &rwa)
+ got := s.shouldSendInvitation(ctx, &org, &rwa)
assert.False(t, got)
})
- t.Run("returns false when receiver reached the maximum number of SMS resend attempts", func(t *testing.T) {
- var smsResendInterval int64 = 2
+ t.Run("returns false when receiver reached the maximum number of message resend attempts", func(t *testing.T) {
+ var msgResendInterval int64 = 2
invitationSentAt := time.Now()
- org := data.Organization{SMSResendInterval: &smsResendInterval}
+ org := data.Organization{ReceiverInvitationResendIntervalDays: &msgResendInterval}
rwa := data.ReceiverWalletAsset{
ReceiverWallet: data.ReceiverWallet{
InvitationSentAt: &invitationSentAt,
@@ -1064,7 +1090,7 @@ func Test_SendReceiverWalletInviteService_shouldSendInvitationSMS(t *testing.T)
PhoneNumber: "+123456789",
},
ReceiverWalletStats: data.ReceiverWalletStats{
- TotalInvitationSMSResentAttempts: maxInvitationSMSResendAttempts,
+ TotalInvitationResentAttempts: maxInvitationResendAttempts,
},
},
WalletID: "wallet-ID",
@@ -1072,22 +1098,22 @@ func Test_SendReceiverWalletInviteService_shouldSendInvitationSMS(t *testing.T)
getEntries := log.DefaultLogger.StartTest(log.DebugLevel)
- got := s.shouldSendInvitationSMS(ctx, &org, &rwa)
+ got := s.shouldSendInvitation(ctx, &org, &rwa)
assert.False(t, got)
entries := getEntries()
require.Len(t, entries, 1)
assert.Equal(
t,
- "the invitation message was not resent to the receiver because the maximum number of SMS resend attempts has been reached: Phone Number: +12...789 - Receiver ID receiver-ID - Wallet ID wallet-ID - Total Invitation SMS resent 3 - Maximum attempts 3",
+ "the invitation message was not resent to the receiver because the maximum number of message resend attempts has been reached: Receiver ID receiver-ID - Wallet ID wallet-ID - Total Invitation resent 3 - Maximum attempts 3",
entries[0].Message,
)
})
- t.Run("returns false when the receiver is not in the period to resend the SMS", func(t *testing.T) {
+ t.Run("returns false when the receiver is not in the period to resend the message", func(t *testing.T) {
var smsResendInterval int64 = 2
invitationSentAt := time.Now().AddDate(0, 0, -int(smsResendInterval-1))
- org := data.Organization{SMSResendInterval: &smsResendInterval}
+ org := data.Organization{ReceiverInvitationResendIntervalDays: &smsResendInterval}
rwa := data.ReceiverWalletAsset{
ReceiverWallet: data.ReceiverWallet{
InvitationSentAt: &invitationSentAt,
@@ -1096,7 +1122,7 @@ func Test_SendReceiverWalletInviteService_shouldSendInvitationSMS(t *testing.T)
PhoneNumber: "+123456789",
},
ReceiverWalletStats: data.ReceiverWalletStats{
- TotalInvitationSMSResentAttempts: 1,
+ TotalInvitationResentAttempts: 1,
},
},
WalletID: "wallet-ID",
@@ -1104,7 +1130,7 @@ func Test_SendReceiverWalletInviteService_shouldSendInvitationSMS(t *testing.T)
getEntries := log.DefaultLogger.StartTest(log.DebugLevel)
- got := s.shouldSendInvitationSMS(ctx, &org, &rwa)
+ got := s.shouldSendInvitation(ctx, &org, &rwa)
assert.False(t, got)
entries := getEntries()
@@ -1112,7 +1138,7 @@ func Test_SendReceiverWalletInviteService_shouldSendInvitationSMS(t *testing.T)
assert.Equal(
t,
fmt.Sprintf(
- "the invitation message was not automatically resent to the receiver because the receiver is not in the resend period: Phone Number: +12...789 - Receiver ID receiver-ID - Wallet ID wallet-ID - Last Invitation Sent At %s - SMS Resend Interval 2 day(s)",
+ "the invitation message was not automatically resent to the receiver because the receiver is not in the resend period: Receiver ID receiver-ID - Wallet ID wallet-ID - Last Invitation Sent At %s - Receiver Invitation Resend Interval 2 day(s)",
invitationSentAt.Format(time.RFC1123),
),
entries[0].Message,
@@ -1124,16 +1150,19 @@ func Test_SendReceiverWalletInviteService_shouldSendInvitationSMS(t *testing.T)
// 2 days after receiving the first invitation
invitationSentAt := time.Now().Add((-25 * 2) * time.Hour)
- org := data.Organization{SMSResendInterval: &smsResendInterval}
+ org := data.Organization{ReceiverInvitationResendIntervalDays: &smsResendInterval}
rwa := data.ReceiverWalletAsset{
ReceiverWallet: data.ReceiverWallet{
InvitationSentAt: &invitationSentAt,
ReceiverWalletStats: data.ReceiverWalletStats{
- TotalInvitationSMSResentAttempts: 0,
+ TotalInvitationResentAttempts: 0,
+ },
+ Receiver: data.Receiver{
+ PhoneNumber: "+380443973607",
},
},
}
- got := s.shouldSendInvitationSMS(ctx, &org, &rwa)
+ got := s.shouldSendInvitation(ctx, &org, &rwa)
assert.True(t, got)
// 4 days after receiving the first invitation
@@ -1142,11 +1171,14 @@ func Test_SendReceiverWalletInviteService_shouldSendInvitationSMS(t *testing.T)
ReceiverWallet: data.ReceiverWallet{
InvitationSentAt: &invitationSentAt,
ReceiverWalletStats: data.ReceiverWalletStats{
- TotalInvitationSMSResentAttempts: 1,
+ TotalInvitationResentAttempts: 1,
+ },
+ Receiver: data.Receiver{
+ PhoneNumber: "+380443973607",
},
},
}
- got = s.shouldSendInvitationSMS(ctx, &org, &rwa)
+ got = s.shouldSendInvitation(ctx, &org, &rwa)
assert.True(t, got)
// 6 days after receiving the first invitation
@@ -1155,11 +1187,14 @@ func Test_SendReceiverWalletInviteService_shouldSendInvitationSMS(t *testing.T)
ReceiverWallet: data.ReceiverWallet{
InvitationSentAt: &invitationSentAt,
ReceiverWalletStats: data.ReceiverWalletStats{
- TotalInvitationSMSResentAttempts: 2,
+ TotalInvitationResentAttempts: 2,
+ },
+ Receiver: data.Receiver{
+ PhoneNumber: "+380443973607",
},
},
}
- got = s.shouldSendInvitationSMS(ctx, &org, &rwa)
+ got = s.shouldSendInvitation(ctx, &org, &rwa)
assert.True(t, got)
// 8 days after receiving the first invitation - we don't resend because it reached the maximum number of attempts
@@ -1168,11 +1203,14 @@ func Test_SendReceiverWalletInviteService_shouldSendInvitationSMS(t *testing.T)
ReceiverWallet: data.ReceiverWallet{
InvitationSentAt: &invitationSentAt,
ReceiverWalletStats: data.ReceiverWalletStats{
- TotalInvitationSMSResentAttempts: 3,
+ TotalInvitationResentAttempts: 3,
+ },
+ Receiver: data.Receiver{
+ PhoneNumber: "+380443973607",
},
},
}
- got = s.shouldSendInvitationSMS(ctx, &org, &rwa)
+ got = s.shouldSendInvitation(ctx, &org, &rwa)
assert.False(t, got)
})
}
diff --git a/internal/services/setup_wallets_for_network_service.go b/internal/services/setup_wallets_for_network_service.go
index 8d6084d81..51b96a4bd 100644
--- a/internal/services/setup_wallets_for_network_service.go
+++ b/internal/services/setup_wallets_for_network_service.go
@@ -32,6 +32,7 @@ func SetupWalletsForProperNetwork(ctx context.Context, dbConnectionPool db.DBCon
}
var names, homepages, deepLinkSchemas, sep10ClientDomains []string
+ var userManagedFlags []bool
separator := strings.Repeat("-", 20)
buf := new(strings.Builder)
@@ -41,6 +42,7 @@ func SetupWalletsForProperNetwork(ctx context.Context, dbConnectionPool db.DBCon
homepages = append(homepages, wallet.Homepage)
deepLinkSchemas = append(deepLinkSchemas, wallet.DeepLinkSchema)
sep10ClientDomains = append(sep10ClientDomains, wallet.SEP10ClientDomain)
+ userManagedFlags = append(userManagedFlags, wallet.UserManaged)
buf.WriteString(fmt.Sprintf("%s\n%s\n\n", wallet.Name, separator))
}
@@ -54,7 +56,8 @@ func SetupWalletsForProperNetwork(ctx context.Context, dbConnectionPool db.DBCon
-- gather all wallets passed as parameters for the query and turn into SQL rows
SELECT
UNNEST($1::text[]) AS name, UNNEST($2::text[]) AS homepage,
- UNNEST($3::text[]) AS deep_link_schema, UNNEST($4::text[]) AS sep_10_client_domain
+ UNNEST($3::text[]) AS deep_link_schema, UNNEST($4::text[]) AS sep_10_client_domain,
+ UNNEST($5::bool[]) AS user_managed
),
existing_wallets AS (
-- gets all wallets that the name appears in the names passed as parameter for the query
@@ -82,16 +85,16 @@ func SetupWalletsForProperNetwork(ctx context.Context, dbConnectionPool db.DBCon
)
-- inserts wallets in the database
INSERT INTO wallets
- (name, homepage, deep_link_schema, sep_10_client_domain)
+ (name, homepage, deep_link_schema, sep_10_client_domain, user_managed)
SELECT
- wtui.name, wtui.homepage, wtui.deep_link_schema, wtui.sep_10_client_domain
+ wtui.name, wtui.homepage, wtui.deep_link_schema, wtui.sep_10_client_domain, wtui.user_managed
FROM
wallets_to_update_or_insert wtui
WHERE
wtui.name NOT IN (SELECT name FROM existing_wallets)
`
- _, err := dbTx.ExecContext(ctx, query, pq.Array(names), pq.Array(homepages), pq.Array(deepLinkSchemas), pq.Array(sep10ClientDomains))
+ _, err := dbTx.ExecContext(ctx, query, pq.Array(names), pq.Array(homepages), pq.Array(deepLinkSchemas), pq.Array(sep10ClientDomains), pq.Array(userManagedFlags))
if err != nil {
return fmt.Errorf("error upserting wallets: %w", err)
}
diff --git a/internal/services/wallets/wallets_pubnet.go b/internal/services/wallets/wallets_pubnet.go
index a1840d8fa..d5eb89728 100644
--- a/internal/services/wallets/wallets_pubnet.go
+++ b/internal/services/wallets/wallets_pubnet.go
@@ -24,6 +24,11 @@ var PubnetWallets = []data.Wallet{
assets.USDCAssetPubnet,
},
},
+ {
+ Name: "User Managed Wallet",
+ Assets: assets.AllAssetsPubnet,
+ UserManaged: true,
+ },
// {
// Name: "Beans App",
// Homepage: "https://www.beansapp.com/disbursements",
diff --git a/internal/services/wallets/wallets_testnet.go b/internal/services/wallets/wallets_testnet.go
index bfde96118..d374f99f7 100644
--- a/internal/services/wallets/wallets_testnet.go
+++ b/internal/services/wallets/wallets_testnet.go
@@ -25,4 +25,9 @@ var TestnetWallets = []data.Wallet{
assets.USDCAssetTestnet,
},
},
+ {
+ Name: "User Managed Wallet",
+ Assets: assets.AllAssetsTestnet,
+ UserManaged: true,
+ },
}
diff --git a/internal/statistics/calculate_statistics_test.go b/internal/statistics/calculate_statistics_test.go
index 13b7f404a..f0739e9f0 100644
--- a/internal/statistics/calculate_statistics_test.go
+++ b/internal/statistics/calculate_statistics_test.go
@@ -97,7 +97,6 @@ func TestCalculateStatistics(t *testing.T) {
require.NoError(t, err)
asset1 := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://")
receiver1 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{})
@@ -107,11 +106,10 @@ func TestCalculateStatistics(t *testing.T) {
receiverWallet2 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet.ID, data.DraftReceiversWalletStatus)
disbursement1 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Name: "disbursement 1",
- Status: data.CompletedDisbursementStatus,
- Asset: asset1,
- Wallet: wallet,
- Country: country,
+ Name: "disbursement 1",
+ Status: data.CompletedDisbursementStatus,
+ Asset: asset1,
+ Wallet: wallet,
})
stellarTransactionID, err := utils.RandomString(64)
@@ -220,11 +218,10 @@ func TestCalculateStatistics(t *testing.T) {
asset2 := data.CreateAssetFixture(t, ctx, dbConnectionPool, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
disbursement2 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{
- Name: "disbursement 2",
- Status: data.CompletedDisbursementStatus,
- Asset: asset2,
- Wallet: wallet,
- Country: country,
+ Name: "disbursement 2",
+ Status: data.CompletedDisbursementStatus,
+ Asset: asset2,
+ Wallet: wallet,
})
stellarTransactionID, err = utils.RandomString(64)
@@ -401,7 +398,6 @@ func Test_checkIfDisbursementExists(t *testing.T) {
t.Run("disbursement exists", func(t *testing.T) {
asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV")
- country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France")
wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://")
disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, model.Disbursements, &data.Disbursement{
@@ -412,9 +408,8 @@ func Test_checkIfDisbursementExists(t *testing.T) {
UserID: "user1",
},
},
- Asset: asset,
- Country: country,
- Wallet: wallet,
+ Asset: asset,
+ Wallet: wallet,
})
exists, err := checkIfDisbursementExists(context.Background(), dbConnectionPool, disbursement.ID)
require.NoError(t, err)
diff --git a/internal/utils/network_type.go b/internal/utils/network_type.go
index 86bd3de5d..428bbf881 100644
--- a/internal/utils/network_type.go
+++ b/internal/utils/network_type.go
@@ -28,6 +28,14 @@ func (n NetworkType) Validate() error {
return nil
}
+func (n NetworkType) IsPubnet() bool {
+ return n == PubnetNetworkType
+}
+
+func (n NetworkType) IsTestnet() bool {
+ return n == TestnetNetworkType
+}
+
func GetNetworkTypeFromNetworkPassphrase(networkPassphrase string) (NetworkType, error) {
switch networkPassphrase {
case network.PublicNetworkPassphrase:
diff --git a/internal/utils/network_type_test.go b/internal/utils/network_type_test.go
index c7b8b20dd..e2f21b739 100644
--- a/internal/utils/network_type_test.go
+++ b/internal/utils/network_type_test.go
@@ -44,6 +44,58 @@ func Test_NetworkType_Validate(t *testing.T) {
}
}
+func Test_NetworkType_IsTestnet(t *testing.T) {
+ testCases := []struct {
+ networkType NetworkType
+ expectedResult bool
+ }{
+ {
+ networkType: TestnetNetworkType,
+ expectedResult: true,
+ },
+ {
+ networkType: PubnetNetworkType,
+ expectedResult: false,
+ },
+ {
+ networkType: "UNSUPPORTED",
+ expectedResult: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(string(tc.networkType), func(t *testing.T) {
+ assert.Equal(t, tc.expectedResult, tc.networkType.IsTestnet())
+ })
+ }
+}
+
+func Test_NetworkType_IsPubnet(t *testing.T) {
+ testCases := []struct {
+ networkType NetworkType
+ expectedResult bool
+ }{
+ {
+ networkType: TestnetNetworkType,
+ expectedResult: false,
+ },
+ {
+ networkType: PubnetNetworkType,
+ expectedResult: true,
+ },
+ {
+ networkType: "UNSUPPORTED",
+ expectedResult: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(string(tc.networkType), func(t *testing.T) {
+ assert.Equal(t, tc.expectedResult, tc.networkType.IsPubnet())
+ })
+ }
+}
+
func Test_GetNetworkTypeFromNetworkPassphrase(t *testing.T) {
testCases := []struct {
networkPassphrase string
diff --git a/internal/utils/sql.go b/internal/utils/sql.go
new file mode 100644
index 000000000..c3f4b4e29
--- /dev/null
+++ b/internal/utils/sql.go
@@ -0,0 +1,10 @@
+package utils
+
+import "database/sql"
+
+func SQLNullString(s string) sql.NullString {
+ return sql.NullString{
+ String: s,
+ Valid: s != "",
+ }
+}
diff --git a/internal/utils/string.go b/internal/utils/string.go
index e072e336d..df2f77f0b 100644
--- a/internal/utils/string.go
+++ b/internal/utils/string.go
@@ -4,6 +4,7 @@ import (
"crypto/rand"
"fmt"
"math/big"
+ "strings"
)
const (
@@ -38,3 +39,13 @@ func TruncateString(str string, borderSizeToKeep int) string {
}
return str[:borderSizeToKeep] + "..." + str[len(str)-borderSizeToKeep:]
}
+
+// TrimAndLower trims and lowercases a string.
+func TrimAndLower(str string) string {
+ return strings.TrimSpace(strings.ToLower(str))
+}
+
+// Humanize converts a string to a human readable format.
+func Humanize(str string) string {
+ return strings.ToLower(strings.ReplaceAll(str, "_", " "))
+}
diff --git a/internal/utils/utils.go b/internal/utils/utils.go
index 3bcca2950..bab048c05 100644
--- a/internal/utils/utils.go
+++ b/internal/utils/utils.go
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"reflect"
+ "strconv"
"strings"
"time"
@@ -95,6 +96,31 @@ func StringPtr(s string) *string {
return &s
}
+// IntPtr returns a pointer to an int
+func IntPtr(i int) *int {
+ return &i
+}
+
func TimePtr(t time.Time) *time.Time {
return &t
}
+
+func VisualBool(b bool) string {
+ if b {
+ return "🟢"
+ }
+ return "🔴"
+}
+
+// ParseBoolQueryParam parses a boolean query parameter from an HTTP request.
+func ParseBoolQueryParam(r *http.Request, param string) (*bool, error) {
+ paramValue := r.URL.Query().Get(param)
+ if paramValue == "" {
+ return nil, nil
+ }
+ parsedValue, err := strconv.ParseBool(paramValue)
+ if err != nil {
+ return nil, fmt.Errorf("invalid '%s' parameter value: %w", param, err)
+ }
+ return &parsedValue, nil
+}
diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go
index dddce9f95..4cc60e7fd 100644
--- a/internal/utils/utils_test.go
+++ b/internal/utils/utils_test.go
@@ -275,3 +275,57 @@ func TestStringPtr(t *testing.T) {
assert.Equal(t, "initial string", *result)
})
}
+
+// Write a test for ParseBoolQueryParam function.
+func Test_ParseBoolQueryParam(t *testing.T) {
+ trueValue := true
+ falseValue := false
+
+ testCases := []struct {
+ name string
+ queryParam string
+ expectedResult *bool
+ expectedError string
+ }{
+ {
+ name: "valid true value",
+ queryParam: "true",
+ expectedResult: &trueValue,
+ expectedError: "",
+ },
+ {
+ name: "valid false value",
+ queryParam: "false",
+ expectedResult: &falseValue,
+ expectedError: "",
+ },
+ {
+ name: "valid empty value",
+ queryParam: "",
+ expectedResult: nil,
+ expectedError: "",
+ },
+ {
+ name: "invalid value",
+ queryParam: "invalid",
+ expectedResult: nil,
+ expectedError: "invalid 'enabled' parameter value",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req, err := http.NewRequest("GET", fmt.Sprintf("/?enabled=%s", tc.queryParam), nil)
+ require.NoError(t, err)
+
+ result, err := ParseBoolQueryParam(req, "enabled")
+ if tc.expectedError != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tc.expectedError)
+ } else {
+ require.NoError(t, err)
+ assert.Equal(t, tc.expectedResult, result)
+ }
+ })
+ }
+}
diff --git a/internal/utils/validation.go b/internal/utils/validation.go
index 0a7591656..96a335d37 100644
--- a/internal/utils/validation.go
+++ b/internal/utils/validation.go
@@ -1,8 +1,11 @@
package utils
import (
+ "errors"
"fmt"
+ "net/url"
"regexp"
+ "slices"
"strconv"
"time"
@@ -16,6 +19,7 @@ var (
rxOTP = regexp.MustCompile(`^\d{6}$`)
ErrInvalidE164PhoneNumber = fmt.Errorf("the provided phone number is not a valid E.164 number")
ErrEmptyPhoneNumber = fmt.Errorf("phone number cannot be empty")
+ ErrEmptyEmail = fmt.Errorf("email cannot be empty")
)
const (
@@ -67,7 +71,7 @@ var rxEmail = regexp.MustCompile("^[a-zA-Z0-9.!#$%&'*+\\/=?^_`{|}~-]+@[a-zA-Z0-9
func ValidateEmail(email string) error {
if email == "" {
- return fmt.Errorf("email cannot be empty")
+ return ErrEmptyEmail
}
if !rxEmail.MatchString(email) {
@@ -161,3 +165,53 @@ func ValidateNationalIDVerification(nationalID string) error {
return nil
}
+
+// ValidatePathIsNotTraversal will validate the given path to ensure it does not contain path traversal.
+func ValidatePathIsNotTraversal(p string) error {
+ if pathTraversalPattern.MatchString(p) {
+ return errors.New("path cannot contain path traversal")
+ }
+
+ return nil
+}
+
+var pathTraversalPattern = regexp.MustCompile(`(^|[\\/])\.\.([\\/]|$)`)
+
+// ValidateURLScheme checks if a URL is valid and if it has a valid scheme.
+func ValidateURLScheme(link string, scheme ...string) error {
+ // Use govalidator to check if it's a valid URL
+ if !govalidator.IsURL(link) {
+ return errors.New("invalid URL format")
+ }
+
+ parsedURL, err := url.ParseRequestURI(link)
+ if err != nil {
+ return errors.New("invalid URL format")
+ }
+
+ // Check if the scheme is valid
+ if len(scheme) > 0 {
+ if !slices.Contains(scheme, parsedURL.Scheme) {
+ return fmt.Errorf("invalid URL scheme is not part of %v", scheme)
+ }
+ }
+
+ return nil
+}
+
+// ValidateNoHTMLNorJSNorCSS detects HTML,