diff --git a/backend/config/config.go b/backend/config/config.go index 3633e230e..ab8ee6018 100644 --- a/backend/config/config.go +++ b/backend/config/config.go @@ -30,6 +30,7 @@ type Config struct { ThirdParty ThirdParty `yaml:"third_party" json:"third_party" koanf:"third_party" split_words:"true"` Log LoggerConfig `yaml:"log" json:"log" koanf:"log"` Account Account `yaml:"account" json:"account" koanf:"account"` + OIDC OIDC `yaml:"oidc" json:"oidc" koanf:"oidc"` } var ( @@ -601,6 +602,20 @@ func (c *Config) PostProcess() error { } +type OIDCClient struct { + ClientID string `yaml:"client_id" json:"client_id" koanf:"client_id"` + ClientSecret string `yaml:"client_secret" json:"client_secret" koanf:"client_secret"` + ClientType string `yaml:"client_type" json:"client_type" koanf:"client_type"` + RedirectURI []string `yaml:"redirect_uri" json:"redirect_uri" koanf:"redirect_uri"` +} + +type OIDC struct { + Enabled bool `yaml:"enabled" json:"enabled" koanf:"enabled"` + Issuer string `yaml:"issuer" json:"issuer" koanf:"issuer"` + Key string `yaml:"key" json:"key" koanf:"key"` + Clients []OIDCClient `yaml:"clients" json:"clients" koanf:"clients"` +} + type LoggerConfig struct { LogHealthAndMetrics bool `yaml:"log_health_and_metrics" json:"log_health_and_metrics" koanf:"log_health_and_metrics"` } diff --git a/backend/config/config.yaml b/backend/config/config.yaml index 4484cca07..00f40c44c 100644 --- a/backend/config/config.yaml +++ b/backend/config/config.yaml @@ -1,6 +1,5 @@ database: - user: hanko - password: hanko + user: postgres host: localhost port: 5432 dialect: postgres diff --git a/backend/go.mod b/backend/go.mod index 600f66211..03e14261f 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -31,10 +31,12 @@ require ( github.com/sethvargo/go-redisstore v0.3.0 github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.8.4 + github.com/zitadel/oidc/v2 v2.6.3 golang.org/x/crypto v0.10.0 golang.org/x/oauth2 v0.9.0 golang.org/x/text v0.10.0 gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df + gopkg.in/square/go-jose.v2 v2.6.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -78,11 +80,14 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang-jwt/jwt/v4 v4.5.0 // indirect - github.com/golang/protobuf v1.5.2 // indirect + github.com/golang/protobuf v1.5.3 // indirect github.com/google/go-tpm v0.3.3 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/google/uuid v1.3.0 // indirect github.com/gorilla/css v1.0.0 // indirect + github.com/gorilla/mux v1.8.0 // indirect + github.com/gorilla/schema v1.2.0 // indirect + github.com/gorilla/securecookie v1.1.1 // indirect github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect github.com/iancoleman/orderedmap v0.0.0-20190318233801-ac98e3ecb4b0 // indirect github.com/imdario/mergo v0.3.13 // indirect @@ -120,6 +125,7 @@ require ( github.com/moby/term v0.0.0-20221205130635-1aeaba878587 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/muhlemmer/gu v0.3.1 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 // indirect github.com/opencontainers/runc v1.1.5 // indirect @@ -132,10 +138,11 @@ require ( github.com/prometheus/common v0.40.0 // indirect github.com/prometheus/procfs v0.9.0 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect + github.com/rs/cors v1.9.0 // indirect github.com/segmentio/asm v1.2.0 // indirect github.com/sergi/go-diff v1.2.0 // indirect github.com/shopspring/decimal v1.3.1 // indirect - github.com/sirupsen/logrus v1.9.0 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d // indirect github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e // indirect github.com/spf13/pflag v1.0.5 // indirect @@ -154,7 +161,7 @@ require ( golang.org/x/time v0.3.0 // indirect golang.org/x/tools v0.7.0 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/protobuf v1.28.1 // indirect + google.golang.org/protobuf v1.29.1 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/backend/go.sum b/backend/go.sum index 16cd7898d..2dc5f1848 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -203,6 +203,7 @@ github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -217,8 +218,9 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gomodule/redigo v1.8.2/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUzlpkBYO0= github.com/gomodule/redigo v1.8.9 h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws= @@ -251,6 +253,12 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/gorilla/schema v1.2.0 h1:YufUaxZYCKGFuAq3c96BOhjgd5nmXiOY9NGzF247Tsc= +github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= +github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= +github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= @@ -493,6 +501,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/mrunalp/fileutils v0.5.0/go.mod h1:M1WthSahJixYnrXQl/DFQuteStB1weuxD2QJNHXfbSQ= +github.com/muhlemmer/gu v0.3.1 h1:7EAqmFrW7n3hETvuAdmFmn4hS8W+z3LgKtrnow+YzNM= +github.com/muhlemmer/gu v0.3.1/go.mod h1:YHtHR+gxM+bKEIIs7Hmi9sPT3ZDUvTN/i88wQpZkrdM= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4= @@ -568,6 +578,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE= +github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= @@ -602,8 +614,9 @@ github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMB github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d h1:yKm7XZV6j9Ev6lojP2XaIshpT4ymkqhMeSghO5Ps00E= github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= @@ -674,6 +687,8 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +github.com/zitadel/oidc/v2 v2.6.3 h1:YY87cAcdI+3voZqcRU2RGz3Pxky/2KsjDmYDVb6EgWw= +github.com/zitadel/oidc/v2 v2.6.3/go.mod h1:2LrbdKYLSgKxXBfct56ev4e186J7TXotlZxb6tExOO4= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/etcd/api/v3 v3.5.4/go.mod h1:5GB2vv4A4AOn3yk7MftYGHkUfGtDHnEraIjym4dYz5A= go.etcd.io/etcd/client/pkg/v3 v3.5.4/go.mod h1:IJHfcCEKxYu1Os13ZdwCwIUTUVGYTSAM3YSwc9/Ac1g= @@ -915,8 +930,8 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= -google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.29.1 h1:7QBf+IK2gx70Ap/hDsOmam3GE0v9HicjfEdAxE62UoM= +google.golang.org/protobuf v1.29.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= @@ -931,6 +946,8 @@ gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkp gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= +gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/backend/handler/oidc.go b/backend/handler/oidc.go new file mode 100644 index 000000000..097586b57 --- /dev/null +++ b/backend/handler/oidc.go @@ -0,0 +1,144 @@ +package handler + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "github.com/gofrs/uuid" + "github.com/labstack/echo/v4" + "github.com/lestrrat-go/jwx/v2/jwt" + auditlog "github.com/teamhanko/hanko/backend/audit_log" + "github.com/teamhanko/hanko/backend/config" + "github.com/teamhanko/hanko/backend/handler/oidc" + "github.com/teamhanko/hanko/backend/persistence" + "github.com/teamhanko/hanko/backend/session" + "github.com/zitadel/oidc/v2/pkg/op" + "golang.org/x/text/language" + "net/http" +) + +type OIDCHandler struct { + cfg *config.Config + persister persistence.Persister + sessionManager session.Manager + auditLogger auditlog.Logger + provider op.OpenIDProvider +} + +func NewOIDCHandler( + cfg *config.Config, + persister persistence.Persister, + sessionManager session.Manager, + auditLogger auditlog.Logger, +) *OIDCHandler { + if !cfg.OIDC.Enabled { + return nil + } + + key, err := base64.URLEncoding.DecodeString(cfg.OIDC.Key) + if err != nil { + panic(err) + } + + if len(key) != 32 { + panic("key must be 32 bytes long") + } + + pathLoggedOut := "/logged_out" + + var extraOptions []op.Option + + config := &op.Config{ + CryptoKey: [32]byte(key), + + // will be used if the end_session endpoint is called without a post_logout_redirect_uri + DefaultLogoutRedirectURI: pathLoggedOut, + + // enables code_challenge_method S256 for PKCE (and therefore PKCE in general) + CodeMethodS256: true, + + // enables additional client_id/client_secret authentication by form post (not only HTTP Basic Auth) + AuthMethodPost: true, + + // enables additional authentication by using private_key_jwt + AuthMethodPrivateKeyJWT: false, + + // enables refresh_token grant use + GrantTypeRefreshToken: true, + + // enables use of the `request` Object parameter + RequestObjectSupported: true, + + // this example has only static texts (in English), so we'll set the here accordingly + SupportedUILocales: []language.Tag{language.English}, + } + + storage := oidc.NewStorage(persister) + for _, client := range cfg.OIDC.Clients { + err := storage.AddClient(&client) + if err != nil { + panic(err) + } + } + + provider, err := op.NewOpenIDProvider(cfg.OIDC.Issuer, config, storage, append([]op.Option{ + op.WithCustomEndpoints( + op.NewEndpoint("/oauth/authorize"), + op.NewEndpoint("/oauth/token"), + op.NewEndpoint("/oauth/userinfo"), + op.NewEndpoint("/oauth/revoke"), + op.NewEndpoint("/oauth/end_session"), + op.NewEndpoint("/oauth/keys"), + ), + op.WithCustomDeviceAuthorizationEndpoint(op.NewEndpoint("/oauth/device_authorization")), + }, extraOptions...)...) + if err != nil { + panic(err) + } + + fmt.Println("OIDC provider initialized") + f := op.AuthCallbackURL(provider) + fmt.Println("OIDC callback url:", f(context.Background(), "testID")) + fmt.Println("OIDC callback url:") + + return &OIDCHandler{ + cfg: cfg, + persister: persister, + sessionManager: sessionManager, + auditLogger: auditLogger, + provider: provider, + } +} + +func (h *OIDCHandler) Handler(c echo.Context) error { + h.provider.HttpHandler().ServeHTTP(c.Response(), c.Request()) + + return nil +} + +func (h *OIDCHandler) LoginHandler(c echo.Context) error { + sessionToken, ok := c.Get("session").(jwt.Token) + if !ok { + return errors.New("failed to cast session object") + } + + authRequestID := c.QueryParam("id") + if authRequestID == "" { + return c.String(400, "id parameter missing") + } + + uid, err := uuid.FromString(authRequestID) + if err != nil { + return c.String(400, "id parameter invalid") + } + + persister := h.persister.GetOIDCAuthRequestPersister() + + err = persister.AuthorizeUser(c.Request().Context(), uid, sessionToken.Subject()) + if err != nil { + return c.String(500, "error authorizing user") + } + + return c.Redirect(http.StatusFound, "/oauth/authorize/callback?id="+authRequestID) +} diff --git a/backend/handler/oidc/client.go b/backend/handler/oidc/client.go new file mode 100644 index 000000000..f07a50a55 --- /dev/null +++ b/backend/handler/oidc/client.go @@ -0,0 +1,224 @@ +package oidc + +import ( + "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v2/pkg/op" + "time" +) + +var ( + // we use the default login UI and pass the (auth request) id + defaultLoginURL = func(id string) string { + return "/login/username?authRequestID=" + id + } + + // clients to be used by the storage interface + clients = map[string]*Client{} +) + +// Client represents the storage model of an OAuth/OIDC client +type Client struct { + id string + secret string + redirectURIs []string + applicationType op.ApplicationType + authMethod oidc.AuthMethod + loginURL func(string) string + responseTypes []oidc.ResponseType + grantTypes []oidc.GrantType + accessTokenType op.AccessTokenType + devMode bool + idTokenUserinfoClaimsAssertion bool + clockSkew time.Duration + postLogoutRedirectURIGlobs []string + redirectURIGlobs []string + + keys map[string]interface{} +} + +// GetID must return the client_id +func (c *Client) GetID() string { + return c.id +} + +// RedirectURIs must return the registered redirect_uris for Code and Implicit Flow +func (c *Client) RedirectURIs() []string { + return c.redirectURIs +} + +// PostLogoutRedirectURIs must return the registered post_logout_redirect_uris for sign-outs +func (c *Client) PostLogoutRedirectURIs() []string { + return []string{} +} + +// ApplicationType must return the type of the client (app, native, user agent) +func (c *Client) ApplicationType() op.ApplicationType { + return c.applicationType +} + +// AuthMethod must return the authentication method (client_secret_basic, client_secret_post, none, private_key_jwt) +func (c *Client) AuthMethod() oidc.AuthMethod { + return c.authMethod +} + +// ResponseTypes must return all allowed response types (code, id_token token, id_token) +// these must match with the allowed grant types +func (c *Client) ResponseTypes() []oidc.ResponseType { + return c.responseTypes +} + +// GrantTypes must return all allowed grant types (authorization_code, refresh_token, urn:ietf:params:oauth:grant-type:jwt-bearer) +func (c *Client) GrantTypes() []oidc.GrantType { + return c.grantTypes +} + +// LoginURL will be called to redirect the user (agent) to the login UI +// you could implement some logic here to redirect the users to different login UIs depending on the client +func (c *Client) LoginURL(id string) string { + return c.loginURL(id) +} + +// AccessTokenType must return the type of access token the client uses (Bearer (opaque) or JWT) +func (c *Client) AccessTokenType() op.AccessTokenType { + return c.accessTokenType +} + +// IDTokenLifetime must return the lifetime of the client's id_tokens +func (c *Client) IDTokenLifetime() time.Duration { + return 1 * time.Hour +} + +// DevMode enables the use of non-compliant configs such as redirect_uris (e.g. http schema for user agent client) +func (c *Client) DevMode() bool { + return c.devMode +} + +// RestrictAdditionalIdTokenScopes allows specifying which custom scopes shall be asserted into the id_token +func (c *Client) RestrictAdditionalIdTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } +} + +// RestrictAdditionalAccessTokenScopes allows specifying which custom scopes shall be asserted into the JWT access_token +func (c *Client) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } +} + +// IsScopeAllowed enables Client specific custom scopes validation +// in this example we allow the CustomScope for all clients +func (c *Client) IsScopeAllowed(scope string) bool { + return false +} + +// IDTokenUserinfoClaimsAssertion allows specifying if claims of scope profile, email, phone and address are asserted into the id_token +// even if an access token if issued which violates the OIDC Core spec +// (5.4. Requesting Claims using Scope Values: https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims) +// some clients though require that e.g. email is always in the id_token when requested even if an access_token is issued +func (c *Client) IDTokenUserinfoClaimsAssertion() bool { + return c.idTokenUserinfoClaimsAssertion +} + +// ClockSkew enables clients to instruct the OP to apply a clock skew on the various times and expirations +// (subtract from issued_at, add to expiration, ...) +func (c *Client) ClockSkew() time.Duration { + return c.clockSkew +} + +func (c *Client) GetKey(keyID string) (interface{}, bool) { + k, ok := c.keys[keyID] + return k, ok +} + +// RegisterClients enables you to register clients for the example implementation +// there are some clients (web and native) to try out different cases +// add more if necessary +// +// RegisterClients should be called before the Storage is used so that there are +// no race conditions. +func RegisterClients(registerClients ...*Client) { + for _, client := range registerClients { + clients[client.id] = client + } +} + +// NativeClient will create a client of type native, which will always use PKCE and allow the use of refresh tokens +// user-defined redirectURIs may include: +// - http://localhost without port specification (e.g. http://localhost/auth/callback) +// - custom protocol (e.g. custom://auth/callback) +// (the examples will be used as default, if none is provided) +func NativeClient(id string, redirectURIs ...string) *Client { + if len(redirectURIs) == 0 { + redirectURIs = []string{ + "http://localhost/auth/callback", + "custom://auth/callback", + } + } + + return &Client{ + id: id, + secret: "", // no secret needed (due to PKCE) + redirectURIs: redirectURIs, + applicationType: op.ApplicationTypeNative, + authMethod: oidc.AuthMethodNone, + loginURL: defaultLoginURL, + responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode}, + grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken}, + accessTokenType: op.AccessTokenTypeBearer, + devMode: false, + idTokenUserinfoClaimsAssertion: false, + clockSkew: 0, + } +} + +// WebClient will create a client of type web, which will always use Basic Auth and allow the use of refresh tokens +// user-defined redirectURIs may include: +// - http://localhost with port specification (e.g. http://localhost:9999/auth/callback) +// (the example will be used as default, if none is provided) +func WebClient(id, secret string, redirectURIs ...string) *Client { + if len(redirectURIs) == 0 { + redirectURIs = []string{ + "http://localhost:9999/auth/callback", + } + } + + return &Client{ + id: id, + secret: secret, + redirectURIs: redirectURIs, + applicationType: op.ApplicationTypeWeb, + authMethod: oidc.AuthMethodBasic, + loginURL: defaultLoginURL, + responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode}, + grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken}, + accessTokenType: op.AccessTokenTypeBearer, + devMode: false, + idTokenUserinfoClaimsAssertion: false, + clockSkew: 0, + } +} + +type hasRedirectGlobs struct { + *Client +} + +// RedirectURIGlobs provide wildcarding for additional valid redirects +func (c hasRedirectGlobs) RedirectURIGlobs() []string { + return c.redirectURIGlobs +} + +// PostLogoutRedirectURIGlobs provide extra wildcarding for additional valid redirects +func (c hasRedirectGlobs) PostLogoutRedirectURIGlobs() []string { + return c.postLogoutRedirectURIGlobs +} + +// RedirectGlobsClient wraps the client in a op.HasRedirectGlobs +// only if DevMode is enabled. +func RedirectGlobsClient(client *Client) op.Client { + if client.devMode { + return hasRedirectGlobs{client} + } + return client +} diff --git a/backend/handler/oidc/oidc.go b/backend/handler/oidc/oidc.go new file mode 100644 index 000000000..4d84f75c4 --- /dev/null +++ b/backend/handler/oidc/oidc.go @@ -0,0 +1,202 @@ +package oidc + +import ( + "errors" + "github.com/gofrs/uuid" + "github.com/teamhanko/hanko/backend/persistence/models" + "github.com/zitadel/oidc/v2/pkg/oidc" + "golang.org/x/text/language" + "strings" + "time" +) + +type AuthRequest struct { + ID uuid.UUID + CreationDate time.Time + ApplicationID string + CallbackURI string + TransferState string + Prompt []string + UiLocales []language.Tag + LoginHint string + MaxAuthAge *time.Duration + UserID string + Scopes []string + ResponseType oidc.ResponseType + Nonce string + CodeChallenge string + LoginDone bool + AuthTime time.Time +} + +func NewAuthRequestFromModel(request *models.AuthRequest) (*AuthRequest, error) { + if request == nil { + return nil, errors.New("auth request not found") + } + + var uiLocales []language.Tag + for _, tag := range request.GetUILocales() { + uiLocales = append(uiLocales, language.Make(tag)) + } + + maxAuthAge := request.GetMaxAuthAge() + + return &AuthRequest{ + ID: request.ID, + CreationDate: request.CreatedAt, + ApplicationID: request.ClientID, + CallbackURI: request.CallbackURI, + TransferState: request.TransferState, + Prompt: request.GetPrompt(), + UiLocales: uiLocales, + LoginHint: request.LoginHint, + MaxAuthAge: &maxAuthAge, + UserID: request.UserID, + Scopes: request.GetScopes(), + ResponseType: oidc.ResponseType(request.ResponseType), + Nonce: request.Nonce, + CodeChallenge: request.CodeChallenge, + LoginDone: request.Done, + AuthTime: request.AuthTime, + }, nil +} + +func (a *AuthRequest) GetID() string { + return a.ID.String() +} + +func (a *AuthRequest) GetACR() string { + return "" // we won't handle acr +} + +func (a *AuthRequest) GetAMR() []string { + // TODO: https://www.rfc-editor.org/rfc/rfc8176.html + + // this example only uses password for authentication + if a.LoginDone { + return []string{"pwd"} + } + return nil +} + +func (a *AuthRequest) GetAudience() []string { + return []string{a.ApplicationID} // this example will always just use the client_id as audience +} + +func (a *AuthRequest) GetAuthTime() time.Time { + return a.AuthTime +} + +func (a *AuthRequest) GetClientID() string { + return a.ApplicationID +} + +func (a *AuthRequest) GetCodeChallenge() *oidc.CodeChallenge { + return &oidc.CodeChallenge{ + Challenge: a.CodeChallenge, + Method: oidc.CodeChallengeMethodS256, + } +} + +func (a *AuthRequest) GetNonce() string { + return a.Nonce +} + +func (a *AuthRequest) GetRedirectURI() string { + return a.CallbackURI +} + +func (a *AuthRequest) GetResponseType() oidc.ResponseType { + return a.ResponseType +} + +func (a *AuthRequest) GetResponseMode() oidc.ResponseMode { + return "" // we won't handle response mode +} + +func (a *AuthRequest) GetScopes() []string { + return a.Scopes +} + +func (a *AuthRequest) GetState() string { + return a.TransferState +} + +func (a *AuthRequest) GetSubject() string { + return a.UserID +} + +func (a *AuthRequest) Done() bool { + return a.LoginDone +} + +func (a *AuthRequest) ToModel() models.AuthRequest { + var locales []string + for _, locale := range a.UiLocales { + locales = append(locales, locale.String()) + } + + var maxAuthAge time.Duration + if a.MaxAuthAge != nil { + maxAuthAge = *a.MaxAuthAge + } + + return models.AuthRequest{ + ID: a.ID, + CreatedAt: a.CreationDate, + ClientID: a.ApplicationID, + CallbackURI: a.CallbackURI, + TransferState: a.TransferState, + Prompt: strings.Join(a.Prompt, ","), + UILocales: strings.Join(locales, ","), + LoginHint: a.LoginHint, + MaxAuthAge: int64(maxAuthAge.Seconds()), + UserID: a.UserID, + Scopes: strings.Join(a.Scopes, ","), + ResponseType: string(a.ResponseType), + Nonce: a.Nonce, + CodeChallenge: a.CodeChallenge, + } +} + +func PromptToInternal(oidcPrompt oidc.SpaceDelimitedArray) []string { + prompts := make([]string, len(oidcPrompt)) + for _, oidcPrompt := range oidcPrompt { + switch oidcPrompt { + case oidc.PromptNone, + oidc.PromptLogin, + oidc.PromptConsent, + oidc.PromptSelectAccount: + prompts = append(prompts, oidcPrompt) + } + } + + return prompts +} + +func MaxAgeToInternal(maxAge *uint) *time.Duration { + if maxAge == nil { + return nil + } + + dur := time.Duration(*maxAge) * time.Second + return &dur +} + +func authRequestToInternal(authReq *oidc.AuthRequest, userID string) *AuthRequest { + return &AuthRequest{ + CreationDate: time.Now(), + ApplicationID: authReq.ClientID, + CallbackURI: authReq.RedirectURI, + TransferState: authReq.State, + Prompt: PromptToInternal(authReq.Prompt), + UiLocales: authReq.UILocales, + LoginHint: authReq.LoginHint, + MaxAuthAge: MaxAgeToInternal(authReq.MaxAge), + UserID: userID, + Scopes: authReq.Scopes, + ResponseType: authReq.ResponseType, + Nonce: authReq.Nonce, + CodeChallenge: authReq.CodeChallenge, + } +} diff --git a/backend/handler/oidc/storage.go b/backend/handler/oidc/storage.go new file mode 100644 index 000000000..094e6e334 --- /dev/null +++ b/backend/handler/oidc/storage.go @@ -0,0 +1,768 @@ +package oidc + +import ( + "context" + "fmt" + "github.com/gofrs/uuid" + "github.com/teamhanko/hanko/backend/config" + "github.com/teamhanko/hanko/backend/persistence" + "github.com/teamhanko/hanko/backend/persistence/models" + "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v2/pkg/op" + "gopkg.in/square/go-jose.v2" + "strings" + "sync" + "time" +) + +type Storage struct { + lock sync.RWMutex + + clients map[string]*Client + accessTokenExpiration time.Duration + refreshTokenExpiration time.Duration + + accessTokens persistence.OIDCAccessTokenPersister + refreshTokens persistence.OIDCRefreshTokenPersister + authRequests persistence.OIDCAuthRequestPersister + keys persistence.OIDCKeyPersister + users persistence.UserPersister +} + +func NewStorage(persister persistence.Persister) *Storage { + return &Storage{ + accessTokens: persister.GetOIDCAccessTokenPersister(), + refreshTokens: persister.GetOIDCRefreshTokenPersister(), + authRequests: persister.GetOIDCAuthRequestPersister(), + keys: persister.GetOIDCKeyPersister(), + users: persister.GetUserPersister(), + clients: make(map[string]*Client), + accessTokenExpiration: time.Minute, + refreshTokenExpiration: time.Hour * 24 * 30, + } +} + +func (s *Storage) AddClient(client *config.OIDCClient) error { + if client.ClientType == "native" { + s.clients[client.ClientID] = NativeClient(client.ClientID, client.RedirectURI...) + + return nil + } + + if client.ClientType == "web" { + s.clients[client.ClientID] = WebClient(client.ClientID, client.ClientSecret, client.RedirectURI...) + + return nil + } + + return fmt.Errorf("unknown client type: %s", client.ClientType) +} + +// CreateAuthRequest implements the op.Storage interface +// it will be called after parsing and validation of the authentication request +func (s *Storage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (op.AuthRequest, error) { + fmt.Println("storage: CreateAuthRequest") + + if len(req.Prompt) == 1 && req.Prompt[0] == "none" { + // With prompt=none, there is no way for the user to log in + // so return error right away. + return nil, oidc.ErrLoginRequired() + } + + // typically, you'll fill your storage / storage model with the information of the passed object + request := authRequestToInternal(req, userID) + + // you'll also have to create a unique id for the request (this might be done by your database; we'll use a uuid) + uid, err := uuid.NewV4() + if err != nil { + return nil, fmt.Errorf("failed to generate uuid: %w", err) + } + + request.ID = uid + + fmt.Println("request: ", request) + fmt.Println("request userID: ", userID) + + // and save it in your database (for demonstration purposed we will use a simple map) + err = s.authRequests.Create(ctx, request.ToModel()) + if err != nil { + return nil, err + } + + return request, nil +} + +// AuthRequestByID implements the op.Storage interface +// it will be called after the Login UI redirects back to the OIDC endpoint +func (s *Storage) AuthRequestByID(ctx context.Context, id string) (op.AuthRequest, error) { + fmt.Println("storage: AuthRequestByID") + + uid, err := uuid.FromString(id) + if err != nil { + return nil, fmt.Errorf("failed parse uuid: %w", err) + } + + request, err := s.authRequests.Get(ctx, uid) + if err != nil { + return nil, fmt.Errorf("could not get auth request: %w", err) + } + + fmt.Println("request: ", request) + + return NewAuthRequestFromModel(request) +} + +// AuthRequestByCode implements the op.Storage interface +// it will be called after parsing and validation of the token request (in an authorization code flow) +func (s *Storage) AuthRequestByCode(ctx context.Context, code string) (op.AuthRequest, error) { + fmt.Println("storage: AuthRequestByCode") + + request, err := s.authRequests.GetAuthRequestByCode(ctx, code) + if err != nil { + return nil, fmt.Errorf("could not get auth request by code: %w", err) + } + + return NewAuthRequestFromModel(request) +} + +// SaveAuthCode implements the op.Storage interface +// it will be called after the authentication has been successful and before redirecting the user agent to the +// redirect_uri (in an authorization code flow) +func (s *Storage) SaveAuthCode(ctx context.Context, id string, code string) error { + fmt.Println("storage: SaveAuthCode") + + uid, err := uuid.FromString(id) + if err != nil { + return fmt.Errorf("failed parse uuid: %w", err) + } + + err = s.authRequests.StoreAuthCode(ctx, uid, code) + if err != nil { + return fmt.Errorf("could not store auth code: %w", err) + } + + return nil +} + +// DeleteAuthRequest implements the op.Storage interface +// it will be called after creating the token response (id and access tokens) for a valid +// - authentication request (in an implicit flow) +// - token request (in an authorization code flow) +func (s *Storage) DeleteAuthRequest(ctx context.Context, id string) error { + fmt.Println("storage: DeleteAuthRequest") + + uid, err := uuid.FromString(id) + if err != nil { + return fmt.Errorf("failed parse uuid: %w", err) + } + + err = s.authRequests.Delete(ctx, uid) + if err != nil { + return fmt.Errorf("could not delete auth request: %w", err) + } + + return nil +} + +// createAccessToken will store an access_token in-memory based on the provided information +func (s *Storage) createAccessToken(ctx context.Context, clientID, subject string, refreshTokenID uuid.UUID, audience, scopes []string) (*models.AccessToken, error) { + fmt.Println("storage: createAccessToken") + + uid, err := uuid.NewV4() + if err != nil { + return nil, fmt.Errorf("failed to generate uuid: %w", err) + } + + var refreshToken *models.RefreshToken + if refreshTokenID != uuid.Nil { + refreshToken = &models.RefreshToken{ID: refreshTokenID} + } + + token := models.AccessToken{ + ID: uid, + ClientID: clientID, + RefreshToken: refreshToken, + Subject: subject, + Audience: strings.Join(audience, ","), + ExpiresAt: time.Now().Add(s.accessTokenExpiration), + Scopes: strings.Join(scopes, ","), + } + + err = s.accessTokens.Create(ctx, token) + if err != nil { + return nil, err + } + + return &token, nil +} + +// createRefreshToken will store a refresh_token in-memory based on the provided information +func (s *Storage) createRefreshToken(ctx context.Context, id uuid.UUID, clientID string, subject string, audience []string, scopes []string, amr []string, authTime time.Time) (*models.RefreshToken, error) { + fmt.Println("storage: createRefreshToken") + + token := models.RefreshToken{ + ID: id, + AuthTime: authTime, + AMR: strings.Join(amr, ","), + ClientID: clientID, + UserID: subject, + Audience: strings.Join(audience, ","), + ExpiresAt: time.Now().Add(s.refreshTokenExpiration), + Scopes: strings.Join(scopes, ","), + } + + err := s.refreshTokens.Create(ctx, token) + if err != nil { + return nil, err + } + + return &token, err +} + +// renewRefreshToken checks the provided refresh_token and creates a new one based on the current +func (s *Storage) renewRefreshToken(ctx context.Context, clientID, currentRefreshToken string) (*models.RefreshToken, error) { + fmt.Println("storage: renewRefreshToken") + + uid, err := uuid.FromString(currentRefreshToken) + if err != nil { + return nil, fmt.Errorf("failed to parse uuid: %w", err) + } + + token, err := s.refreshTokens.Get(ctx, uid) + if err != nil { + return nil, fmt.Errorf("failed to get refresh token: %w", err) + } + + if token.ClientID != clientID { + return nil, op.ErrInvalidRefreshToken + } + + // deletes the refresh token and all access tokens which were issued based on this refresh token + err = s.refreshTokens.Delete(ctx, *token) + if err != nil { + return nil, fmt.Errorf("failed to delete refresh token: %w", err) + } + + // creates a new refresh token based on the current one + uid, err = uuid.NewV4() + if err != nil { + return nil, fmt.Errorf("failed to generate uuid: %w", err) + } + + token.ID = uid + + err = s.refreshTokens.Create(ctx, *token) + if err != nil { + return nil, fmt.Errorf("failed to create refresh token: %w", err) + } + + return token, nil +} + +func (s *Storage) exchangeRefreshToken(ctx context.Context, request op.TokenExchangeRequest) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) { + fmt.Println("storage: exchangeRefreshToken") + + applicationID := request.GetClientID() + authTime := request.GetAuthTime() + + refreshTokenID, err := uuid.NewV4() + if err != nil { + return "", "", time.Time{}, fmt.Errorf("failed to generate uuid: %w", err) + } + + refreshToken, err := s.createRefreshToken(ctx, refreshTokenID, applicationID, request.GetSubject(), request.GetAudience(), request.GetScopes(), request.GetAMR(), authTime) + if err != nil { + return "", "", time.Time{}, err + } + + accessToken, err := s.createAccessToken(ctx, applicationID, request.GetSubject(), refreshTokenID, request.GetAudience(), request.GetScopes()) + if err != nil { + return "", "", time.Time{}, err + } + + return accessToken.ID.String(), refreshToken.ID.String(), accessToken.ExpiresAt, nil +} + +// CreateAccessToken implements the op.Storage interface +// it will be called for all requests able to return an access token (Authorization Code Flow, Implicit Flow, JWT Profile, ...) +func (s *Storage) CreateAccessToken(ctx context.Context, request op.TokenRequest) (accessTokenID string, expiration time.Time, err error) { + fmt.Println("storage: CreateAccessToken") + + var applicationID string + switch req := request.(type) { + case *AuthRequest: + // if authenticated for an app (auth code / implicit flow) we must save the client_id to the token + applicationID = req.ApplicationID + case op.TokenExchangeRequest: + applicationID = req.GetClientID() + default: + panic("invalid state encountered") + } + + token, err := s.createAccessToken(ctx, applicationID, request.GetSubject(), uuid.Nil, request.GetAudience(), request.GetScopes()) + if err != nil { + return "", time.Time{}, err + } + + return token.ID.String(), token.ExpiresAt, nil +} + +// CreateAccessAndRefreshTokens implements the op.Storage interface +// it will be called for all requests able to return an access and refresh token (Authorization Code Flow, Refresh Token Request) +func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshTokenID string, expiration time.Time, err error) { + fmt.Println("storage: CreateAccessAndRefreshTokens") + + // generate tokens via token exchange flow if request is relevant + if teReq, ok := request.(op.TokenExchangeRequest); ok { + return s.exchangeRefreshToken(ctx, teReq) + } + + // get the information depending on the request type / implementation + applicationID, authTime, amr := getInfoFromRequest(request) + + // if currentRefreshToken is empty (Code Flow) we will have to create a new refresh token + if currentRefreshToken == "" { + refreshTokenID, err := uuid.NewV4() + if err != nil { + return "", "", time.Time{}, fmt.Errorf("failed to generate uuid: %w", err) + } + + refreshToken, err := s.createRefreshToken(ctx, refreshTokenID, applicationID, request.GetSubject(), request.GetAudience(), request.GetScopes(), amr, authTime) + if err != nil { + return "", "", time.Time{}, err + } + + accessToken, err := s.createAccessToken(ctx, applicationID, request.GetSubject(), refreshTokenID, request.GetAudience(), request.GetScopes()) + if err != nil { + return "", "", time.Time{}, err + } + + return accessToken.ID.String(), refreshToken.ID.String(), accessToken.ExpiresAt, nil + } + + // if we get here, the currentRefreshToken was not empty, so the call is a refresh token request + // we therefore will have to check the currentRefreshToken and renew the refresh token + refreshToken, err := s.renewRefreshToken(ctx, applicationID, currentRefreshToken) + if err != nil { + return "", "", time.Time{}, err + } + + accessToken, err := s.createAccessToken(ctx, applicationID, request.GetSubject(), refreshToken.ID, request.GetAudience(), request.GetScopes()) + if err != nil { + return "", "", time.Time{}, err + } + + return accessToken.ID.String(), refreshToken.ID.String(), accessToken.ExpiresAt, nil +} + +// TokenRequestByRefreshToken implements the op.Storage interface +// it will be called after parsing and validation of the refresh token request +func (s *Storage) TokenRequestByRefreshToken(ctx context.Context, refreshTokenID string) (op.RefreshTokenRequest, error) { + fmt.Println("storage: TokenRequestByRefreshToken") + + uid, err := uuid.FromString(refreshTokenID) + if err != nil { + return nil, fmt.Errorf("failed to parse refresh token id: %w", err) + } + + refreshToken, err := s.refreshTokens.Get(ctx, uid) + if err != nil { + return nil, fmt.Errorf("failed to get refresh token: %w", err) + } + + return RefreshTokenRequestFromBusiness(refreshToken), nil +} + +// TerminateSession implements the op.Storage interface +// it will be called after the user signed out, therefore the access and refresh token of the user of this client must be removed +func (s *Storage) TerminateSession(ctx context.Context, userID string, clientID string) error { + fmt.Println("storage: TerminateSession") + + err := s.refreshTokens.TerminateSessions(ctx, userID, clientID) + if err != nil { + return fmt.Errorf("error terminating session: %w", err) + } + + return nil +} + +// RevokeToken implements the op.Storage interface +// it will be called after parsing and validation of the token revocation request +func (s *Storage) RevokeToken(ctx context.Context, tokenOrTokenID string, userID string, clientID string) *oidc.Error { + fmt.Println("storage: RevokeToken") + + uid, err := uuid.FromString(tokenOrTokenID) + if err != nil { + return oidc.ErrInvalidRequest().WithDescription("invalid accessToken") + } + + accessToken, err := s.accessTokens.Get(ctx, uid) + if err == nil && accessToken != nil { + if accessToken.ClientID != clientID { + return oidc.ErrInvalidClient().WithDescription("accessToken was not issued for this client") + } + + err = s.accessTokens.Delete(ctx, *accessToken) + if err != nil { + return oidc.ErrServerError().WithDescription(err.Error()) + } + + return nil + } + + refreshToken, err := s.refreshTokens.Get(ctx, uid) + if err == nil && refreshToken == nil { + // if the token is neither an access nor a refresh token, just ignore it, the expected behaviour of + // being not valid (anymore) is achieved + return nil + } + + if err != nil { + return oidc.ErrServerError().WithDescription("failed to get refreshToken") + } + + if accessToken.ClientID != clientID { + return oidc.ErrInvalidClient().WithDescription("refreshToken was not issued for this client") + } + + // This should also take care of deleting the access token + err = s.refreshTokens.Delete(ctx, *refreshToken) + if err != nil { + return oidc.ErrServerError().WithDescription(err.Error()) + } + + return nil +} + +// GetRefreshTokenInfo looks up a refresh token and returns the token id and user id. +// If given something that is not a refresh token, it must return error. +func (s *Storage) GetRefreshTokenInfo(ctx context.Context, clientID string, tokenStr string) (userID string, tokenID string, err error) { + fmt.Println("storage: GetRefreshTokenInfo") + + uid, err := uuid.FromString(tokenStr) + if err != nil { + return "", "", op.ErrInvalidRefreshToken + } + + token, err := s.refreshTokens.Get(ctx, uid) + if err == nil && token == nil { + return "", "", op.ErrInvalidRefreshToken + } + + if err != nil { + return "", "", fmt.Errorf("failed to get refresh token: %w", err) + } + + if token.ClientID != clientID { + return "", "", op.ErrInvalidRefreshToken + } + + return token.UserID, token.ID.String(), nil +} + +// SigningKey implements the op.Storage interface +// it will be called when creating the OpenID Provider +func (s *Storage) SigningKey(ctx context.Context) (op.SigningKey, error) { + fmt.Println("storage: SigningKey") + + key, err := s.keys.GetSigningKey(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get signing key: %w", err) + } + + if key == nil { + return nil, fmt.Errorf("no signing key found") + } + + return key, nil +} + +// SignatureAlgorithms implements the op.Storage interface +// it will be called to get the sign +func (s *Storage) SignatureAlgorithms(ctx context.Context) ([]jose.SignatureAlgorithm, error) { + fmt.Println("storage: SignatureAlgorithms") + + key, err := s.keys.GetSigningKey(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get signing key: %w", err) + } + + if key == nil { + return nil, fmt.Errorf("no signing key found") + } + + return []jose.SignatureAlgorithm{key.SignatureAlgorithm()}, nil +} + +// KeySet implements the op.Storage interface +// it will be called to get the current (public) keys, among others for the keys_endpoint or for validating access_tokens on the userinfo_endpoint, ... +func (s *Storage) KeySet(ctx context.Context) ([]op.Key, error) { + fmt.Println("storage: KeySet") + + keys, err := s.keys.GetPublicKeys(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get signing keys: %w", err) + } + + var opKeys []op.Key + for _, key := range keys { + opKeys = append(opKeys, &key) + } + + return opKeys, nil +} + +// GetClientByClientID implements the op.Storage interface +// it will be called whenever information (type, redirect_uris, ...) about the client behind the client_id is needed +func (s *Storage) GetClientByClientID(ctx context.Context, clientID string) (op.Client, error) { + fmt.Println("storage: GetClientByClientID") + + s.lock.RLock() + defer s.lock.RUnlock() + + fmt.Println("storage: GetClientByClientID: clientID: ", clientID) + fmt.Println(s.clients) + + client, ok := s.clients[clientID] + if !ok { + return nil, oidc.ErrInvalidClient() + } + + return client, nil +} + +// AuthorizeClientIDSecret implements the op.Storage interface +// it will be called for validating the client_id, client_secret on token or introspection requests +func (s *Storage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error { + fmt.Println("storage: AuthorizeClientIDSecret") + + s.lock.RLock() + defer s.lock.RUnlock() + + client, ok := s.clients[clientID] + if !ok { + return oidc.ErrInvalidClient() + } + + if client.secret != clientSecret { + return oidc.ErrUnauthorizedClient() + } + + return nil +} + +// SetUserinfoFromScopes implements the op.Storage interface. +// Provide an empty implementation and use SetUserinfoFromRequest instead. +func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error { + fmt.Println("storage: SetUserinfoFromScopes") + + return nil +} + +// setUserinfo sets the info based on the user, scopes and if necessary the clientID +func (s *Storage) setUserinfo(ctx context.Context, userInfo *oidc.UserInfo, userID, clientID string, scopes []string) (err error) { + uid, err := uuid.FromString(userID) + if err != nil { + return fmt.Errorf("invalid userID") + } + + user, err := s.users.Get(uid) + if err != nil { + return fmt.Errorf("failed to get user: %w", err) + } + + if user == nil { + return fmt.Errorf("user not found") + } + + for _, scope := range scopes { + switch scope { + case oidc.ScopeOpenID: + userInfo.Subject = user.ID.String() + case oidc.ScopeEmail: + primaryEmail := user.Emails.GetPrimary() + if primaryEmail == nil { + return fmt.Errorf("no primary email found") + } + + userInfo.Email = primaryEmail.Address + userInfo.EmailVerified = oidc.Bool(primaryEmail.Verified) + /* + case oidc.ScopeProfile: + userInfo.PreferredUsername = user.Username + userInfo.Name = user.FirstName + " " + user.LastName + userInfo.FamilyName = user.LastName + userInfo.GivenName = user.FirstName + userInfo.Locale = oidc.NewLocale(user.PreferredLanguage) + case oidc.ScopePhone: + userInfo.PhoneNumber = user.Phone + userInfo.PhoneNumberVerified = user.PhoneVerified + */ + } + } + + return nil +} + +// SetUserinfoFromToken implements the op.Storage interface +// it will be called for the userinfo endpoint, so we read the token and pass the information from that to the private function +func (s *Storage) SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error { + fmt.Println("storage: SetUserinfoFromToken") + + uid, err := uuid.FromString(tokenID) + if err != nil { + return fmt.Errorf("failed to parse token id: %w", err) + } + + token, err := s.accessTokens.Get(ctx, uid) + if err != nil { + return fmt.Errorf("failed to get token: %w", err) + } + + if token == nil { + return fmt.Errorf("token not found") + } + + if token.ExpiresAt.Before(time.Now()) { + return fmt.Errorf("token has expired") + } + + return s.setUserinfo(ctx, userinfo, token.Subject, token.ClientID, token.GetScopes()) +} + +// SetIntrospectionFromToken implements the op.Storage interface +// it will be called for the introspection endpoint, so we read the token and pass the information from that to the private function +func (s *Storage) SetIntrospectionFromToken(ctx context.Context, userinfo *oidc.IntrospectionResponse, tokenID, subject, clientID string) error { + fmt.Println("storage: SetIntrospectionFromToken") + + uid, err := uuid.FromString(tokenID) + if err != nil { + return fmt.Errorf("failed to parse token id: %w", err) + } + + token, err := s.accessTokens.Get(ctx, uid) + if err != nil { + return fmt.Errorf("failed to get token: %w", err) + } + + if token == nil { + return fmt.Errorf("token not found") + } + + if token.ExpiresAt.Before(time.Now()) { + return fmt.Errorf("token has expired") + } + + for _, aud := range token.GetAudience() { + if aud == clientID { + // the introspection response only has to return a boolean (active) if the token is active + // this will automatically be done by the library if you don't return an error + // you can also return further information about the user / associated token + // e.g. the userinfo (equivalent to userinfo endpoint) + + userInfo := new(oidc.UserInfo) + err := s.setUserinfo(ctx, userInfo, subject, clientID, token.GetScopes()) + if err != nil { + return err + } + + userinfo.SetUserInfo(userInfo) + //...and also the requested scopes... + userinfo.Scope = token.GetScopes() + //...and the client the token was issued to + userinfo.ClientID = token.ClientID + + return nil + } + } + + return fmt.Errorf("token is not valid for this client") +} + +func (s *Storage) getPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]interface{}, err error) { + fmt.Println("storage: getPrivateClaimsFromScopes") + + for _, scope := range scopes { + switch scope { + } + } + return claims, nil +} + +// GetPrivateClaimsFromScopes implements the op.Storage interface +// it will be called for the creation of a JWT access token to assert claims for custom scopes +func (s *Storage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]interface{}, error) { + fmt.Println("storage: GetPrivateClaimsFromScopes") + + return s.getPrivateClaimsFromScopes(ctx, userID, clientID, scopes) +} + +// GetKeyByIDAndClientID implements the op.Storage interface +// it will be called to validate the signatures of a JWT (JWT Profile Grant and Authentication) +func (s *Storage) GetKeyByIDAndClientID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error) { + fmt.Println("storage: GetKeyByIDAndClientID") + + s.lock.RLock() + defer s.lock.RUnlock() + + client, ok := s.clients[clientID] + if !ok { + return nil, fmt.Errorf("clientID not found") + } + + key, ok := client.GetKey(keyID) + if !ok { + return nil, fmt.Errorf("key not found") + } + + return &jose.JSONWebKey{ + KeyID: keyID, + Use: "sig", + Key: key, + }, nil +} + +// ValidateJWTProfileScopes implements the op.Storage interface +// it will be called to validate the scopes of a JWT Profile Authorization Grant request +func (s *Storage) ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error) { + fmt.Println("storage: ValidateJWTProfileScopes") + + allowedScopes := make([]string, 0) + for _, scope := range scopes { + if scope == oidc.ScopeOpenID { + allowedScopes = append(allowedScopes, scope) + } + } + return allowedScopes, nil +} + +func (s *Storage) Health(ctx context.Context) error { + //TODO implement me + panic("implement me") +} + +// SetUserinfoFromRequest implements the op.CanSetUserinfoFromRequest interface. In the +// next major release, it will be required for op.Storage. +// It will be called for the creation of an id_token, so we'll just pass it to the private function without any further check +func (s *Storage) SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, token op.IDTokenRequest, scopes []string) error { + fmt.Println("storage: SetUserinfoFromRequest") + + return s.setUserinfo(ctx, userinfo, token.GetSubject(), token.GetClientID(), scopes) +} + +// getInfoFromRequest returns the clientID, authTime and amr depending on the op.TokenRequest type / implementation +func getInfoFromRequest(req op.TokenRequest) (clientID string, authTime time.Time, amr []string) { + fmt.Println("storage: getInfoFromRequest") + + authReq, ok := req.(*AuthRequest) // Code Flow (with scope offline_access) + if ok { + return authReq.ApplicationID, authReq.GetAuthTime(), authReq.GetAMR() + } + + refreshReq, ok := req.(*RefreshTokenRequest) // Refresh Token Request + if ok { + return refreshReq.ClientID, refreshReq.AuthTime, refreshReq.GetAMR() + } + + return "", time.Time{}, nil +} diff --git a/backend/handler/oidc/token.go b/backend/handler/oidc/token.go new file mode 100644 index 000000000..ac775addd --- /dev/null +++ b/backend/handler/oidc/token.go @@ -0,0 +1,45 @@ +package oidc + +import ( + "github.com/teamhanko/hanko/backend/persistence/models" + "github.com/zitadel/oidc/v2/pkg/op" + "strings" + "time" +) + +// RefreshTokenRequestFromBusiness will simply wrap the storage RefreshToken to implement the op.RefreshTokenRequest interface +func RefreshTokenRequestFromBusiness(token *models.RefreshToken) op.RefreshTokenRequest { + return &RefreshTokenRequest{token} +} + +type RefreshTokenRequest struct { + *models.RefreshToken +} + +func (r *RefreshTokenRequest) GetAMR() []string { + return r.GetAMR() +} + +func (r *RefreshTokenRequest) GetAudience() []string { + return r.GetAudience() +} + +func (r *RefreshTokenRequest) GetAuthTime() time.Time { + return r.AuthTime +} + +func (r *RefreshTokenRequest) GetClientID() string { + return r.ClientID +} + +func (r *RefreshTokenRequest) GetScopes() []string { + return r.GetScopes() +} + +func (r *RefreshTokenRequest) GetSubject() string { + return r.UserID +} + +func (r *RefreshTokenRequest) SetCurrentScopes(scopes []string) { + r.Scopes = strings.Join(scopes, ",") +} diff --git a/backend/handler/oidc_test.go b/backend/handler/oidc_test.go new file mode 100644 index 000000000..64a59abc7 --- /dev/null +++ b/backend/handler/oidc_test.go @@ -0,0 +1,166 @@ +package handler + +import ( + "encoding/json" + "github.com/gofrs/uuid" + "github.com/stretchr/testify/suite" + "github.com/teamhanko/hanko/backend/crypto/jwk" + "github.com/teamhanko/hanko/backend/session" + "github.com/teamhanko/hanko/backend/test" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +func TestOIDCSuite(t *testing.T) { + suite.Run(t, new(oidcSuite)) +} + +type oidcSuite struct { + test.Suite +} + +func (s *oidcSuite) TestOIDCHandler_Paths() { + cfg := &test.DefaultConfig + cfg.OIDC.Enabled = true + + persister := test.NewPersister(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + handlers := NewOIDCHandler(cfg, persister, nil, nil) + + s.Equal("/oauth/authorize", handlers.provider.AuthorizationEndpoint().Relative()) + s.Equal("/oauth/device_authorization", handlers.provider.DeviceAuthorizationEndpoint().Relative()) + s.Equal("/oauth/end_session", handlers.provider.EndSessionEndpoint().Relative()) + s.Equal("/oauth/introspect", handlers.provider.IntrospectionEndpoint().Relative()) + s.Equal("/oauth/keys", handlers.provider.KeysEndpoint().Relative()) + s.Equal("/oauth/revoke", handlers.provider.RevocationEndpoint().Relative()) + s.Equal("/oauth/token", handlers.provider.TokenEndpoint().Relative()) + s.Equal("/oauth/userinfo", handlers.provider.UserinfoEndpoint().Relative()) +} + +func (s *oidcSuite) TestOIDCHandler_well_known() { + cfg := &test.DefaultConfig + cfg.OIDC.Enabled = true + + req := httptest.NewRequest(http.MethodGet, "/.well-known/openid-configuration", nil) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + e := NewPublicRouter(cfg, s.Storage, nil) + e.ServeHTTP(rec, req) + + s.Equal(http.StatusOK, rec.Code) + + body, err := io.ReadAll(rec.Body) + s.Require().NoError(err) + + var data map[string]interface{} + err = json.Unmarshal(body, &data) + + s.Equal("https://example.hanko.io", data["issuer"]) + s.Equal("https://example.hanko.io/oauth/authorize", data["authorization_endpoint"]) +} + +func (s *oidcSuite) TestOIDCHandler_authorize() { + cfg := &test.DefaultConfig + cfg.OIDC.Enabled = true + + err := s.LoadFixtures("../test/fixtures/oidc") + s.Require().NoError(err) + + // Request generator: https://zitadel.com/docs/apis/openidoauth/authrequest + path := "/oauth/authorize?client_id=19286ac4-2216-44dd-bb21-02a41ea3548d&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fcallback&response_type=code&scope=openid%20email%20profile%20offline_access&code_challenge=iMnq5o6zALKXGivsnlom_0F5_WYda32GHkxlV7mq7hQ&code_challenge_method=S256" + req := httptest.NewRequest(http.MethodGet, path, nil) + rev := httptest.NewRecorder() + + e := NewPublicRouter(cfg, s.Storage, nil) + e.ServeHTTP(rev, req) + + s.Equal(http.StatusFound, rev.Code) + + // TODO: check redirect location + s.True(strings.HasPrefix(rev.Header().Get("Location"), "/login/username?authRequestID=")) + + uri, err := url.Parse(rev.Header().Get("Location")) + s.Require().NoError(err) + + authRequestId := uri.Query().Get("authRequestID") + + // This is the tricky bit now - we need to simulate a login flow + // Because Hanko has no built-in redirects in the login flow - we might need to do this client side or add a + // redirect parameter somewhere with a custom redirect function. + path = "/oauth/login?id=" + authRequestId + req = httptest.NewRequest(http.MethodGet, path, nil) + rev = httptest.NewRecorder() + + jwkManager, err := jwk.NewDefaultManager(test.DefaultConfig.Secrets.Keys, s.Storage.GetJwkPersister()) + s.Require().NoError(err) + sessionManager, err := session.NewManager(jwkManager, test.DefaultConfig) + s.Require().NoError(err) + token, err := sessionManager.GenerateJWT(uuid.FromStringOrNil("b5dd5267-b462-48be-b70d-bcd6f1bbe7a5")) + s.Require().NoError(err) + cookie, err := sessionManager.GenerateCookie(token) + s.Require().NoError(err) + req.AddCookie(cookie) + + e.ServeHTTP(rev, req) + + s.Equal(http.StatusFound, rev.Code) + + // Let's follow the redirect from login + path = rev.Header().Get("Location") + req = httptest.NewRequest(http.MethodGet, path, nil) + rev = httptest.NewRecorder() + + e.ServeHTTP(rev, req) + + s.Equal(http.StatusFound, rev.Code) + s.True(strings.HasPrefix(rev.Header().Get("Location"), "http://localhost:8080/callback?code=")) + + uri, err = url.Parse(rev.Header().Get("Location")) + s.Require().NoError(err) + + code := uri.Query().Get("code") + + // This is now back with the client - let's simulate the token exchange + path = "/oauth/token?grant_type=authorization_code&code=" + code + "&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fcallback" + req = httptest.NewRequest(http.MethodGet, path, nil) + req.SetBasicAuth("19286ac4-2216-44dd-bb21-02a41ea3548d", "104cff48ae574505874884973de1f2488b8cd56ea55fdd45b2649a071af94617") + rev = httptest.NewRecorder() + + e.ServeHTTP(rev, req) + + s.Equal(http.StatusOK, rev.Code) + + var data map[string]interface{} + err = json.Unmarshal(rev.Body.Bytes(), &data) + s.Require().NoError(err) + + s.NotEmpty(data["access_token"]) + s.NotEmpty(data["refresh_token"]) + + // And let's also check out the userinfo endpoint + path = "/oauth/introspect" + req = httptest.NewRequest(http.MethodPost, path, strings.NewReader("token="+data["access_token"].(string))) + req.SetBasicAuth("19286ac4-2216-44dd-bb21-02a41ea3548d", "104cff48ae574505874884973de1f2488b8cd56ea55fdd45b2649a071af94617") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rev = httptest.NewRecorder() + + e.ServeHTTP(rev, req) + + s.Equal(http.StatusOK, rev.Code) + + var introspect map[string]interface{} + + err = json.Unmarshal(rev.Body.Bytes(), &introspect) + s.Require().NoError(err) + + s.Equal(introspect["active"], true) + s.Equal(introspect["scope"], "openid email profile offline_access") + s.Equal(introspect["client_id"], "19286ac4-2216-44dd-bb21-02a41ea3548d") + s.Equal(introspect["sub"], "b5dd5267-b462-48be-b70d-bcd6f1bbe7a5") + s.Equal(introspect["email"], "john.doe@example.com") + s.Equal(introspect["email_verified"], true) +} diff --git a/backend/handler/passcode_test.go b/backend/handler/passcode_test.go index dcd577b8a..a0d0bd4a1 100644 --- a/backend/handler/passcode_test.go +++ b/backend/handler/passcode_test.go @@ -19,13 +19,13 @@ import ( ) func TestNewPasscodeHandler(t *testing.T) { - passcodeHandler, err := NewPasscodeHandler(&config.Config{}, test.NewPersister(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil), sessionManager{}, mailer{}, test.NewAuditLogger()) + passcodeHandler, err := NewPasscodeHandler(&config.Config{}, test.NewPersister(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil), sessionManager{}, mailer{}, test.NewAuditLogger()) assert.NoError(t, err) assert.NotEmpty(t, passcodeHandler) } func TestPasscodeHandler_Init(t *testing.T) { - passcodeHandler, err := NewPasscodeHandler(&config.Config{}, test.NewPersister(users, nil, nil, nil, nil, nil, nil, emails, nil, nil, nil), sessionManager{}, mailer{}, test.NewAuditLogger()) + passcodeHandler, err := NewPasscodeHandler(&config.Config{}, test.NewPersister(users, nil, nil, nil, nil, nil, nil, emails, nil, nil, nil, nil, nil, nil, nil, nil), sessionManager{}, mailer{}, test.NewAuditLogger()) require.NoError(t, err) body := dto.PasscodeInitRequest{ @@ -47,7 +47,7 @@ func TestPasscodeHandler_Init(t *testing.T) { } func TestPasscodeHandler_Init_UnknownUserId(t *testing.T) { - passcodeHandler, err := NewPasscodeHandler(&config.Config{}, test.NewPersister(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil), sessionManager{}, mailer{}, test.NewAuditLogger()) + passcodeHandler, err := NewPasscodeHandler(&config.Config{}, test.NewPersister(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil), sessionManager{}, mailer{}, test.NewAuditLogger()) require.NoError(t, err) body := dto.PasscodeInitRequest{ @@ -71,7 +71,7 @@ func TestPasscodeHandler_Init_UnknownUserId(t *testing.T) { } func TestPasscodeHandler_Finish(t *testing.T) { - passcodeHandler, err := NewPasscodeHandler(&config.Config{}, test.NewPersister(users, passcodes(), nil, nil, nil, nil, nil, nil, nil, nil, nil), sessionManager{}, mailer{}, test.NewAuditLogger()) + passcodeHandler, err := NewPasscodeHandler(&config.Config{}, test.NewPersister(users, passcodes(), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil), sessionManager{}, mailer{}, test.NewAuditLogger()) require.NoError(t, err) body := dto.PasscodeFinishRequest{ @@ -94,7 +94,7 @@ func TestPasscodeHandler_Finish(t *testing.T) { } func TestPasscodeHandler_Finish_WrongCode(t *testing.T) { - passcodeHandler, err := NewPasscodeHandler(&config.Config{}, test.NewPersister(nil, passcodes(), nil, nil, nil, nil, nil, nil, nil, nil, nil), sessionManager{}, mailer{}, test.NewAuditLogger()) + passcodeHandler, err := NewPasscodeHandler(&config.Config{}, test.NewPersister(nil, passcodes(), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil), sessionManager{}, mailer{}, test.NewAuditLogger()) require.NoError(t, err) body := dto.PasscodeFinishRequest{ @@ -119,7 +119,7 @@ func TestPasscodeHandler_Finish_WrongCode(t *testing.T) { } func TestPasscodeHandler_Finish_WrongCode_3_Times(t *testing.T) { - passcodeHandler, err := NewPasscodeHandler(&config.Config{}, test.NewPersister(nil, passcodes(), nil, nil, nil, nil, nil, nil, nil, nil, nil), sessionManager{}, mailer{}, test.NewAuditLogger()) + passcodeHandler, err := NewPasscodeHandler(&config.Config{}, test.NewPersister(nil, passcodes(), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil), sessionManager{}, mailer{}, test.NewAuditLogger()) require.NoError(t, err) body := dto.PasscodeFinishRequest{ @@ -153,7 +153,7 @@ func TestPasscodeHandler_Finish_WrongCode_3_Times(t *testing.T) { } func TestPasscodeHandler_Finish_WrongId(t *testing.T) { - passcodeHandler, err := NewPasscodeHandler(&config.Config{}, test.NewPersister(nil, passcodes(), nil, nil, nil, nil, nil, nil, nil, nil, nil), sessionManager{}, mailer{}, test.NewAuditLogger()) + passcodeHandler, err := NewPasscodeHandler(&config.Config{}, test.NewPersister(nil, passcodes(), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil), sessionManager{}, mailer{}, test.NewAuditLogger()) require.NoError(t, err) body := dto.PasscodeFinishRequest{ diff --git a/backend/handler/public_router.go b/backend/handler/public_router.go index cbc3ab6c8..90ca9fee9 100644 --- a/backend/handler/public_router.go +++ b/backend/handler/public_router.go @@ -111,6 +111,24 @@ func NewPublicRouter(cfg *config.Config, persister persistence.Persister, promet wellKnown.GET("/jwks.json", wellKnownHandler.GetPublicKeys) wellKnown.GET("/config", wellKnownHandler.GetConfig) + if cfg.OIDC.Enabled { + oidcHandler := NewOIDCHandler(cfg, persister, sessionManager, auditLogger) + oidc := e.Group("/oauth") + oidc.Any("/authorize", oidcHandler.Handler) + oidc.Any("/authorize/callback", oidcHandler.Handler) + oidc.Any("/device_authorization", oidcHandler.Handler) + oidc.Any("/end_session", oidcHandler.Handler) + oidc.Any("/introspect", oidcHandler.Handler) + oidc.Any("/keys", oidcHandler.Handler) + oidc.Any("/revoke", oidcHandler.Handler) + oidc.Any("/token", oidcHandler.Handler) + oidc.Any("/userinfo", oidcHandler.Handler) + + oidc.GET("/login", oidcHandler.LoginHandler, hankoMiddleware.Session(sessionManager)) + + wellKnown.GET("/openid-configuration", oidcHandler.Handler) + } + emailHandler, err := NewEmailHandler(cfg, persister, sessionManager, auditLogger) if err != nil { panic(fmt.Errorf("failed to create public email handler: %w", err)) diff --git a/backend/persistence/migrations/20230623113054_create_oidc_tokens.down.fizz b/backend/persistence/migrations/20230623113054_create_oidc_tokens.down.fizz new file mode 100644 index 000000000..9f6952fd2 --- /dev/null +++ b/backend/persistence/migrations/20230623113054_create_oidc_tokens.down.fizz @@ -0,0 +1,2 @@ +drop_table("access_tokens") +drop_table("refresh_tokens") diff --git a/backend/persistence/migrations/20230623113054_create_oidc_tokens.up.fizz b/backend/persistence/migrations/20230623113054_create_oidc_tokens.up.fizz new file mode 100644 index 000000000..7dc2b17f9 --- /dev/null +++ b/backend/persistence/migrations/20230623113054_create_oidc_tokens.up.fizz @@ -0,0 +1,25 @@ +create_table("refresh_tokens") { + t.Column("id", "uuid", {}) + t.Column("client_id", "string", {}) + t.Column("user_id", "string", {}) + t.Column("audience", "string", {}) + t.Column("amr", "string", {}) + t.Column("scopes", "string", {}) + t.Column("auth_time", "timestamp", {}) + t.Column("expires_at", "timestamp", {}) + t.Timestamps() + t.PrimaryKey("id") +} + +create_table("access_tokens") { + t.Column("id", "uuid", {}) + t.Column("refresh_token_id", "uuid", {"null": true}) + t.Column("client_id", "string", {}) + t.Column("subject", "string", {}) + t.Column("audience", "string", {}) + t.Column("scopes", "string", {}) + t.Column("expires_at", "timestamp", {}) + t.Timestamps() + t.PrimaryKey("id") + t.ForeignKey("refresh_token_id", {"refresh_tokens": ["id"]}, {"on_delete": "cascade"}) +} diff --git a/backend/persistence/migrations/20230623114845_create_oidc_auth_requests.down.fizz b/backend/persistence/migrations/20230623114845_create_oidc_auth_requests.down.fizz new file mode 100644 index 000000000..cca2bf56f --- /dev/null +++ b/backend/persistence/migrations/20230623114845_create_oidc_auth_requests.down.fizz @@ -0,0 +1,2 @@ +drop_table("auth_codes") +drop_table("auth_requests") diff --git a/backend/persistence/migrations/20230623114845_create_oidc_auth_requests.up.fizz b/backend/persistence/migrations/20230623114845_create_oidc_auth_requests.up.fizz new file mode 100644 index 000000000..d86476d33 --- /dev/null +++ b/backend/persistence/migrations/20230623114845_create_oidc_auth_requests.up.fizz @@ -0,0 +1,27 @@ +create_table("auth_requests") { + t.Column("id", "uuid", {}) + t.Column("client_id", "string", {}) + t.Column("callback_uri", "string", {}) + t.Column("transfer_state", "string", {}) + t.Column("prompt", "string", {}) + t.Column("ui_locales", "string", {}) + t.Column("login_hint", "string", {}) + t.Column("user_id", "string", {}) + t.Column("scopes", "string", {}) + t.Column("response_type", "string", {}) + t.Column("nonce", "string", {}) + t.Column("code_challenge", "string", {}) + t.Column("max_auth_age", "integer", {}) + t.Column("done", "bool", {}) + t.Column("auth_time", "timestamp", {}) + t.Timestamps() + t.PrimaryKey("id") +} + +create_table("auth_codes") { + t.Column("id", "string", {}) + t.Column("auth_request_id", "uuid", {}) + t.Timestamps() + t.PrimaryKey("id") + t.ForeignKey("auth_request_id", {"auth_requests": ["id"]}, {"on_delete": "cascade"}) +} diff --git a/backend/persistence/migrations/20230623115748_create_oidc_keys.down.fizz b/backend/persistence/migrations/20230623115748_create_oidc_keys.down.fizz new file mode 100644 index 000000000..b98ba22e5 --- /dev/null +++ b/backend/persistence/migrations/20230623115748_create_oidc_keys.down.fizz @@ -0,0 +1 @@ +drop_table("keys") diff --git a/backend/persistence/migrations/20230623115748_create_oidc_keys.up.fizz b/backend/persistence/migrations/20230623115748_create_oidc_keys.up.fizz new file mode 100644 index 000000000..df78e40ca --- /dev/null +++ b/backend/persistence/migrations/20230623115748_create_oidc_keys.up.fizz @@ -0,0 +1,9 @@ +create_table("keys") { + t.Column("id", "string", {}) + t.Column("algorithm", "string", {}) + t.Column("public_key", "string", {"size": 4096}) + t.Column("private_key", "string", {"size": 4096}) + t.Column("expires_at", "timestamp", {}) + t.Timestamps() + t.PrimaryKey("id") +} diff --git a/backend/persistence/models/oidc_auth_request.go b/backend/persistence/models/oidc_auth_request.go new file mode 100644 index 000000000..59ca6fd45 --- /dev/null +++ b/backend/persistence/models/oidc_auth_request.go @@ -0,0 +1,62 @@ +package models + +import ( + "github.com/gobuffalo/pop/v6" + "github.com/gobuffalo/validate/v3" + "github.com/gobuffalo/validate/v3/validators" + "github.com/gofrs/uuid" + "strings" + "time" +) + +type AuthRequest struct { + ID uuid.UUID `db:"id" json:"id"` + Codes []AuthCode `has_many:"codes" json:"codes,omitempty"` + + ClientID string `db:"client_id" json:"client_id"` + CallbackURI string `db:"callback_uri" json:"callback_uri"` + TransferState string `db:"transfer_state" json:"transfer_state"` + Prompt string `db:"prompt" json:"prompt"` + UILocales string `db:"ui_locales" json:"ui_locales"` + LoginHint string `db:"login_hint" json:"login_hint"` + UserID string `db:"user_id" json:"user_id"` + Scopes string `db:"scopes" json:"scopes"` + ResponseType string `db:"response_type" json:"response_type"` + Nonce string `db:"nonce" json:"nonce"` + CodeChallenge string `db:"code_challenge" json:"code_challenge"` + MaxAuthAge int64 `db:"max_auth_age" json:"max_auth_age"` + Done bool `db:"done" json:"done"` + AuthTime time.Time `db:"auth_time" json:"auth_time"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + +func (t *AuthRequest) GetPrompt() []string { + return strings.Split(t.Prompt, ",") +} + +func (t *AuthRequest) GetUILocales() []string { + return strings.Split(t.UILocales, ",") +} + +func (t *AuthRequest) GetScopes() []string { + return strings.Split(t.Scopes, ",") +} + +func (t *AuthRequest) GetMaxAuthAge() time.Duration { + return time.Duration(t.MaxAuthAge) * time.Second +} + +func (t *AuthRequest) Validate(tx *pop.Connection) (*validate.Errors, error) { + return validate.Validate( + &validators.UUIDIsPresent{Name: "ID", Field: t.ID}, + ), nil +} + +type AuthCode struct { + ID string `db:"id" json:"id"` + AuthRequest *AuthRequest `belongs_to:"auth_request" json:"auth_request,omitempty"` + AuthRequestID uuid.UUID `db:"auth_request_id" json:"auth_request_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} diff --git a/backend/persistence/models/oidc_key.go b/backend/persistence/models/oidc_key.go new file mode 100644 index 000000000..643fa4c36 --- /dev/null +++ b/backend/persistence/models/oidc_key.go @@ -0,0 +1,102 @@ +package models + +import ( + "crypto/x509" + "encoding/pem" + "github.com/gobuffalo/pop/v6" + "github.com/gobuffalo/validate/v3" + "github.com/gobuffalo/validate/v3/validators" + "github.com/gofrs/uuid" + "gopkg.in/square/go-jose.v2" + "strings" + "time" +) + +type Key struct { + ID uuid.UUID `db:"id" json:"id"` + Algo jose.SignatureAlgorithm `db:"algorithm" json:"algorithm"` + Key string `db:"public_key" json:"public_key"` + PrivateKey string `db:"private_key" json:"private_key"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` +} + +func (k *Key) SigningKey() *SigningKey { + var key interface{} + switch k.Algo { + case jose.RS256, jose.RS384, jose.RS512: + block, _ := pem.Decode([]byte(strings.Replace(k.PrivateKey, "\\n", "\n", -1))) + if block == nil { + panic("failed to parse PEM block containing the key") + } + + priv, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + panic(err) + } + + key = priv + default: + panic("not implemented") + } + + return &SigningKey{ + keyID: k.ID, + algorithm: k.Algo, + privateKey: key, + } +} + +func (k *Key) PublicKey() PublicKey { + return PublicKey{ + keyID: k.ID, + algorithm: k.Algo, + publicKey: k.Key, + } +} + +func (k *Key) Validate(tx *pop.Connection) (*validate.Errors, error) { + return validate.Validate( + &validators.UUIDIsPresent{Name: "ID", Field: k.ID}, + &validators.StringIsPresent{Name: "Algorithm", Field: string(k.Algo)}, + ), nil +} + +type SigningKey struct { + keyID uuid.UUID + algorithm jose.SignatureAlgorithm + privateKey interface{} +} + +func (k *SigningKey) ID() string { + return k.keyID.String() +} + +func (k *SigningKey) SignatureAlgorithm() jose.SignatureAlgorithm { + return k.algorithm +} + +func (k *SigningKey) Key() interface{} { + return k.privateKey +} + +type PublicKey struct { + keyID uuid.UUID + algorithm jose.SignatureAlgorithm + publicKey interface{} +} + +func (k *PublicKey) ID() string { + return k.keyID.String() +} + +func (k *PublicKey) Key() interface{} { + return k.publicKey +} + +func (k *PublicKey) Algorithm() jose.SignatureAlgorithm { + return k.algorithm +} + +func (k *PublicKey) Use() string { + return "sig" +} diff --git a/backend/persistence/models/oidc_token.go b/backend/persistence/models/oidc_token.go new file mode 100644 index 000000000..b5d788fd0 --- /dev/null +++ b/backend/persistence/models/oidc_token.go @@ -0,0 +1,77 @@ +package models + +import ( + "github.com/gobuffalo/pop/v6" + "github.com/gobuffalo/validate/v3" + "github.com/gobuffalo/validate/v3/validators" + "github.com/gofrs/uuid" + "strings" + "time" +) + +type AccessToken struct { + ID uuid.UUID `db:"id" json:"id"` + RefreshToken *RefreshToken `belongs_to:"refresh_tokens" json:"refresh_token,omitempty"` + RefreshTokenID *uuid.UUID `db:"refresh_token_id" json:"refresh_token_id"` + + ClientID string `db:"client_id" json:"client_id"` + Subject string `db:"subject" json:"subject"` + Audience string `db:"audience" json:"audience"` + Scopes string `db:"scopes" json:"scopes"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + +func (t *AccessToken) GetAudience() []string { + return strings.Split(t.ClientID, ",") +} + +func (t *AccessToken) GetScopes() []string { + return strings.Split(t.Scopes, ",") +} + +func (t *AccessToken) Validate(tx *pop.Connection) (*validate.Errors, error) { + return validate.Validate( + &validators.UUIDIsPresent{Name: "ID", Field: t.ID}, + &validators.StringIsPresent{Name: "ClientID", Field: t.ClientID}, + &validators.StringIsPresent{Name: "Subject", Field: t.Subject}, + &validators.TimeIsPresent{Name: "Expires At", Field: t.ExpiresAt}, + ), nil +} + +type RefreshToken struct { + ID uuid.UUID `db:"id" json:"id"` + AccessTokens []AccessToken `has_many:"access_tokens" json:"access_tokens,omitempty"` + + ClientID string `db:"client_id" json:"client_id"` + UserID string `db:"user_id" json:"user_id"` + Audience string `db:"audience" json:"audience"` + AMR string `db:"amr" json:"amr"` + Scopes string `db:"scopes" json:"scopes"` + AuthTime time.Time `db:"auth_time" json:"auth_time"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + +func (t *RefreshToken) GetAudience() []string { + return strings.Split(t.ClientID, ",") +} + +func (t *RefreshToken) GetScopes() []string { + return strings.Split(t.Scopes, ",") +} + +func (t *RefreshToken) GetAMR() []string { + return strings.Split(t.AMR, ",") +} + +func (t *RefreshToken) Validate(tx *pop.Connection) (*validate.Errors, error) { + return validate.Validate( + &validators.UUIDIsPresent{Name: "ID", Field: t.ID}, + &validators.StringIsPresent{Name: "ClientID", Field: t.ClientID}, + &validators.TimeIsPresent{Name: "AuthTime", Field: t.AuthTime}, + &validators.TimeIsPresent{Name: "ExpiresAt", Field: t.ExpiresAt}, + ), nil +} diff --git a/backend/persistence/oidc_auth_request_persister.go b/backend/persistence/oidc_auth_request_persister.go new file mode 100644 index 000000000..ccb9478f6 --- /dev/null +++ b/backend/persistence/oidc_auth_request_persister.go @@ -0,0 +1,114 @@ +package persistence + +import ( + "context" + "database/sql" + "errors" + "fmt" + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/teamhanko/hanko/backend/persistence/models" + "time" +) + +type OIDCAuthRequestPersister interface { + Get(ctx context.Context, uuid uuid.UUID) (*models.AuthRequest, error) + Create(ctx context.Context, authRequest models.AuthRequest) error + Delete(ctx context.Context, uuid uuid.UUID) error + AuthorizeUser(ctx context.Context, uuid uuid.UUID, userID string) error + + StoreAuthCode(ctx context.Context, ID uuid.UUID, code string) error + GetAuthRequestByCode(ctx context.Context, code string) (*models.AuthRequest, error) +} + +type oidcAuthRequestPersister struct { + db *pop.Connection +} + +func NewOIDCAuthRequestPersister(db *pop.Connection) OIDCAuthRequestPersister { + return &oidcAuthRequestPersister{db: db} +} + +func (p *oidcAuthRequestPersister) Get(ctx context.Context, uuid uuid.UUID) (*models.AuthRequest, error) { + authRequest := models.AuthRequest{} + err := p.db.WithContext(ctx).Find(&authRequest, uuid) + if err != nil && errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("failed to get auth request: %w", err) + } + + return &authRequest, nil +} + +func (p *oidcAuthRequestPersister) Create(ctx context.Context, authRequest models.AuthRequest) error { + vErr, err := p.db.WithContext(ctx).ValidateAndCreate(&authRequest) + if err != nil { + return fmt.Errorf("failed to store auth request: %w", err) + } + + if vErr != nil && vErr.HasAny() { + return fmt.Errorf("auth request object validation failed: %w", vErr) + } + + return nil +} + +func (p *oidcAuthRequestPersister) Delete(ctx context.Context, uuid uuid.UUID) error { + err := p.db.WithContext(ctx).Destroy(&models.AuthRequest{ID: uuid}) + if err != nil { + return fmt.Errorf("failed to delete auth request: %w", err) + } + + return nil +} + +func (p *oidcAuthRequestPersister) AuthorizeUser(ctx context.Context, uuid uuid.UUID, userID string) error { + err := p.db.WithContext(ctx).UpdateColumns(&models.AuthRequest{ + ID: uuid, + UserID: userID, + Done: true, + AuthTime: time.Now(), + }, "user_id", "done", "auth_time") + if err != nil { + return fmt.Errorf("failed to authorize user: %w", err) + } + + return nil +} + +func (p *oidcAuthRequestPersister) StoreAuthCode(ctx context.Context, ID uuid.UUID, code string) error { + mCode := models.AuthCode{ + ID: code, + AuthRequest: &models.AuthRequest{ + ID: ID, + }, + } + + vErr, err := p.db.WithContext(ctx).ValidateAndCreate(&mCode) + if err != nil { + return fmt.Errorf("failed to store auth code: %w", err) + } + + if vErr != nil && vErr.HasAny() { + return fmt.Errorf("auth code object validation failed: %w", vErr) + } + + return nil +} + +func (p *oidcAuthRequestPersister) GetAuthRequestByCode(ctx context.Context, code string) (*models.AuthRequest, error) { + authCode := models.AuthCode{} + + err := p.db.WithContext(ctx).EagerPreload().Find(&authCode, code) + if err != nil && errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + + if err != nil { + return nil, fmt.Errorf("failed to get auth code: %w", err) + } + + return authCode.AuthRequest, nil +} diff --git a/backend/persistence/oidc_keys_persister.go b/backend/persistence/oidc_keys_persister.go new file mode 100644 index 000000000..bf8c4ff26 --- /dev/null +++ b/backend/persistence/oidc_keys_persister.go @@ -0,0 +1,55 @@ +package persistence + +import ( + "context" + "database/sql" + "errors" + "fmt" + "github.com/gobuffalo/pop/v6" + "github.com/teamhanko/hanko/backend/persistence/models" + "time" +) + +type OIDCKeyPersister interface { + GetSigningKey(ctx context.Context) (*models.SigningKey, error) + GetPublicKeys(ctx context.Context) ([]models.PublicKey, error) +} + +type oidcKeysPersister struct { + db *pop.Connection +} + +func NewOIDCKeyPersister(db *pop.Connection) OIDCKeyPersister { + return &oidcKeysPersister{db: db} +} + +func (p *oidcKeysPersister) GetSigningKey(ctx context.Context) (*models.SigningKey, error) { + key := models.Key{} + err := p.db.WithContext(ctx).Where("expires_at > ?", time.Now()).Order("expires_at asc").First(&key) + if err != nil && errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("failed to get signing key: %w", err) + } + + return key.SigningKey(), nil +} + +func (p *oidcKeysPersister) GetPublicKeys(ctx context.Context) ([]models.PublicKey, error) { + var keys []models.Key + err := p.db.WithContext(ctx).All(&keys) + if err != nil && errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("failed to get public keys: %w", err) + } + + var publicKeys []models.PublicKey + for _, key := range keys { + publicKeys = append(publicKeys, key.PublicKey()) + } + + return publicKeys, nil +} diff --git a/backend/persistence/oidc_tokens_persister.go b/backend/persistence/oidc_tokens_persister.go new file mode 100644 index 000000000..64e05ec7f --- /dev/null +++ b/backend/persistence/oidc_tokens_persister.go @@ -0,0 +1,121 @@ +package persistence + +import ( + "context" + "database/sql" + "errors" + "fmt" + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/teamhanko/hanko/backend/persistence/models" +) + +type OIDCAccessTokenPersister interface { + Get(ctx context.Context, uuid uuid.UUID) (*models.AccessToken, error) + Create(ctx context.Context, accessToken models.AccessToken) error + Delete(ctx context.Context, accessToken models.AccessToken) error +} + +type oidcAccessTokensPersister struct { + db *pop.Connection +} + +func NewOIDCAccessTokenPersister(db *pop.Connection) OIDCAccessTokenPersister { + return &oidcAccessTokensPersister{db: db} +} + +func (p *oidcAccessTokensPersister) Get(ctx context.Context, uuid uuid.UUID) (*models.AccessToken, error) { + accessToken := models.AccessToken{} + err := p.db.WithContext(ctx).Find(&accessToken, uuid) + if err != nil && errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + return &accessToken, nil +} + +func (p *oidcAccessTokensPersister) Create(ctx context.Context, accessToken models.AccessToken) error { + vErr, err := p.db.WithContext(ctx).ValidateAndCreate(&accessToken) + if err != nil { + return fmt.Errorf("failed to store access token: %w", err) + } + + if vErr != nil && vErr.HasAny() { + return fmt.Errorf("access token object validation failed: %w", vErr) + } + + return nil +} + +func (p *oidcAccessTokensPersister) Delete(ctx context.Context, accessToken models.AccessToken) error { + err := p.db.WithContext(ctx).Destroy(&accessToken) + if err != nil { + return fmt.Errorf("failed to delete access token: %w", err) + } + + return nil +} + +type OIDCRefreshTokenPersister interface { + Get(ctx context.Context, uuid uuid.UUID) (*models.RefreshToken, error) + Create(ctx context.Context, refreshToken models.RefreshToken) error + Delete(ctx context.Context, refreshToken models.RefreshToken) error + TerminateSessions(ctx context.Context, clientID string, userID string) error +} + +type oidcRefreshTokensPersister struct { + db *pop.Connection +} + +func NewOIDCRefreshTokenPersister(db *pop.Connection) OIDCRefreshTokenPersister { + return &oidcRefreshTokensPersister{db: db} +} + +func (p *oidcRefreshTokensPersister) Get(ctx context.Context, uuid uuid.UUID) (*models.RefreshToken, error) { + refreshToken := models.RefreshToken{} + err := p.db.WithContext(ctx).Find(&refreshToken, uuid) + if err != nil && errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + + if err != nil { + return nil, fmt.Errorf("failed to get refresh token: %w", err) + } + + return &refreshToken, nil +} + +func (p *oidcRefreshTokensPersister) Create(ctx context.Context, refreshToken models.RefreshToken) error { + vErr, err := p.db.WithContext(ctx).ValidateAndCreate(&refreshToken) + if err != nil { + return fmt.Errorf("failed to store refresh token: %w", err) + } + + if vErr != nil && vErr.HasAny() { + return fmt.Errorf("refresh token object validation failed: %w", vErr) + } + + return nil +} + +func (p *oidcRefreshTokensPersister) Delete(ctx context.Context, refreshToken models.RefreshToken) error { + // Slight difference: we not only need to delete the RefreshToken - we also need to delete the associated AccessToken + err := p.db.WithContext(ctx).Destroy(&refreshToken) + if err != nil { + return fmt.Errorf("failed to delete refresh token: %w", err) + } + + return nil +} + +func (p *oidcRefreshTokensPersister) TerminateSessions(ctx context.Context, clientID string, userID string) error { + err := p.db.WithContext(ctx).RawQuery("DELETE FROM sessions WHERE client_id = ? AND user_id = ?", clientID, userID).Exec() + if err != nil { + return fmt.Errorf("failed to terminate sessions: %w", err) + } + + return nil +} diff --git a/backend/persistence/persister.go b/backend/persistence/persister.go index dab6b9e4c..8bd1daeb1 100644 --- a/backend/persistence/persister.go +++ b/backend/persistence/persister.go @@ -39,6 +39,14 @@ type Persister interface { GetPrimaryEmailPersisterWithConnection(tx *pop.Connection) PrimaryEmailPersister GetTokenPersister() TokenPersister GetTokenPersisterWithConnection(tx *pop.Connection) TokenPersister + GetOIDCAccessTokenPersister() OIDCAccessTokenPersister + GetOIDCAccessTokenPersisterWithConnection(tx *pop.Connection) OIDCAccessTokenPersister + GetOIDCRefreshTokenPersister() OIDCRefreshTokenPersister + GetOIDCRefreshTokenPersisterWithConnection(tx *pop.Connection) OIDCRefreshTokenPersister + GetOIDCKeyPersister() OIDCKeyPersister + GetOIDCKeyPersisterWithConnection(tx *pop.Connection) OIDCKeyPersister + GetOIDCAuthRequestPersister() OIDCAuthRequestPersister + GetOIDCAuthRequestPersisterWithConnection(tx *pop.Connection) OIDCAuthRequestPersister } type Migrator interface { @@ -204,3 +212,35 @@ func (p *persister) GetTokenPersister() TokenPersister { func (p *persister) GetTokenPersisterWithConnection(tx *pop.Connection) TokenPersister { return NewTokenPersister(tx) } + +func (p *persister) GetOIDCAccessTokenPersister() OIDCAccessTokenPersister { + return NewOIDCAccessTokenPersister(p.DB) +} + +func (p *persister) GetOIDCAccessTokenPersisterWithConnection(tx *pop.Connection) OIDCAccessTokenPersister { + return NewOIDCAccessTokenPersister(tx) +} + +func (p *persister) GetOIDCRefreshTokenPersister() OIDCRefreshTokenPersister { + return NewOIDCRefreshTokenPersister(p.DB) +} + +func (p *persister) GetOIDCRefreshTokenPersisterWithConnection(tx *pop.Connection) OIDCRefreshTokenPersister { + return NewOIDCRefreshTokenPersister(tx) +} + +func (p *persister) GetOIDCKeyPersister() OIDCKeyPersister { + return NewOIDCKeyPersister(p.DB) +} + +func (p *persister) GetOIDCKeyPersisterWithConnection(tx *pop.Connection) OIDCKeyPersister { + return NewOIDCKeyPersister(tx) +} + +func (p *persister) GetOIDCAuthRequestPersister() OIDCAuthRequestPersister { + return NewOIDCAuthRequestPersister(p.DB) +} + +func (p *persister) GetOIDCAuthRequestPersisterWithConnection(tx *pop.Connection) OIDCAuthRequestPersister { + return NewOIDCAuthRequestPersister(tx) +} diff --git a/backend/test/config.go b/backend/test/config.go index 49d55508c..1dc21ca68 100644 --- a/backend/test/config.go +++ b/backend/test/config.go @@ -25,4 +25,17 @@ var DefaultConfig = config.Config{ SameSite: "none", }, }, + OIDC: config.OIDC{ + Enabled: false, + Issuer: "https://example.hanko.io", + Key: "gXK9jVVoRw6m85-XJHdSapaOPnBeifcJ6xcUxC-pJFk=", + Clients: []config.OIDCClient{ + { + ClientID: "19286ac4-2216-44dd-bb21-02a41ea3548d", + ClientSecret: "104cff48ae574505874884973de1f2488b8cd56ea55fdd45b2649a071af94617", + ClientType: "web", + RedirectURI: []string{"http://localhost:8080/callback"}, + }, + }, + }, } diff --git a/backend/test/fixtures/oidc/emails.yaml b/backend/test/fixtures/oidc/emails.yaml new file mode 100644 index 000000000..2587d05a3 --- /dev/null +++ b/backend/test/fixtures/oidc/emails.yaml @@ -0,0 +1,6 @@ +- id: 51b7c175-ceb6-45ba-aae6-0092221c1b84 + user_id: b5dd5267-b462-48be-b70d-bcd6f1bbe7a5 + address: john.doe@example.com + verified: true + created_at: 2020-12-31 23:59:59 + updated_at: 2020-12-31 23:59:59 diff --git a/backend/test/fixtures/oidc/keys.yaml b/backend/test/fixtures/oidc/keys.yaml new file mode 100644 index 000000000..678f12498 --- /dev/null +++ b/backend/test/fixtures/oidc/keys.yaml @@ -0,0 +1,7 @@ +- id: 742dcfe2-64db-4dbc-97b1-23c5802bcc14 + algorithm: RS256 + public_key: -----BEGIN PUBLIC KEY-----\nMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAq4jLWfkJkuL2qdIx+tNL\nWsCOt4eC+2CW2CGz04fGQgJXGafJ9ElifpGMaz3RDsL9XBqpoA0ie6hw/r85FlvC\nVKmeePIvHVQ9eI3ms10ZuNYtIr1DpZK5UxpxoZAoxjWzBKqWsQGApFpanu9Wd/iM\nMICRzu31hG10Z532K0LjYCbhb1cOZE0RZp5helh34e1SksddbqzBcHCueRudxENV\n82sG2tchp1oyW9YaaLUNJu1jAlUlDY7N5cWK68t3ZIYBfOKhD4ltFZS4NP50pWns\nf9+c1piT2agYvt2ATBwcHV8hwrBU0qFaY0R6vVvhwSydqQzTusWYPwhaIs5FAT6X\nS6tM4VTyT3bzYdQ17oTBrT8kdepg2oZhk6Yi51x2smMbxoybWUs/tS9xpoXgJYcR\ntDdnDhQstYM6zNqq1CIWothmWzdtA9bHC5zcG/yB3w/B1f4fIQUUE+fKCPIejurJ\nj0eJQp2XE+AaoqxX7OYE5AAGQGFTQY/lhqXxIxkaIConAZlqC5IDhS4IMoopb9fr\n123+0TDggowFfuhI6EpNKglGlIdf5YmyfmvGyj9ia7z44YrP/NVB7jkonJJEWYe/\nFg0PeHkMc90D1h+RgnWQnSx6c2pTGAPiJkEiyoB3ECaSeKMdVBJBdRsJHNWcld0i\nT2XvZZQC0q3kDGFDHG6F7a0CAwEAAQ==\n-----END PUBLIC KEY-----\n + private_key: -----BEGIN RSA PRIVATE KEY-----\nMIIJKgIBAAKCAgEAq4jLWfkJkuL2qdIx+tNLWsCOt4eC+2CW2CGz04fGQgJXGafJ\n9ElifpGMaz3RDsL9XBqpoA0ie6hw/r85FlvCVKmeePIvHVQ9eI3ms10ZuNYtIr1D\npZK5UxpxoZAoxjWzBKqWsQGApFpanu9Wd/iMMICRzu31hG10Z532K0LjYCbhb1cO\nZE0RZp5helh34e1SksddbqzBcHCueRudxENV82sG2tchp1oyW9YaaLUNJu1jAlUl\nDY7N5cWK68t3ZIYBfOKhD4ltFZS4NP50pWnsf9+c1piT2agYvt2ATBwcHV8hwrBU\n0qFaY0R6vVvhwSydqQzTusWYPwhaIs5FAT6XS6tM4VTyT3bzYdQ17oTBrT8kdepg\n2oZhk6Yi51x2smMbxoybWUs/tS9xpoXgJYcRtDdnDhQstYM6zNqq1CIWothmWzdt\nA9bHC5zcG/yB3w/B1f4fIQUUE+fKCPIejurJj0eJQp2XE+AaoqxX7OYE5AAGQGFT\nQY/lhqXxIxkaIConAZlqC5IDhS4IMoopb9fr123+0TDggowFfuhI6EpNKglGlIdf\n5YmyfmvGyj9ia7z44YrP/NVB7jkonJJEWYe/Fg0PeHkMc90D1h+RgnWQnSx6c2pT\nGAPiJkEiyoB3ECaSeKMdVBJBdRsJHNWcld0iT2XvZZQC0q3kDGFDHG6F7a0CAwEA\nAQKCAgEAiALuowfJpJOcXClUAfuaS3pVb4bev+3ljbijev20oVBzud8GTlIF7DAC\ndGJOqvLHrElj6ImhpwV3mzcK0ASwASuBgYse+pV6LGXv4JbYt2vz3BDQW7AMjK1y\nHlZNTmTz7qZI2E9FrowKQO2r1XLZzfeUJc0fGQMlAqgIsmLWIb6SkBMqUTOesYyx\n5C3T2OuxxrqYBhKrSzm9zj+siBuQQnPBurJMeAMX3SPWSuIMbKpcEFRtDeQGtMM/\nFRp/L9Dlyx1z8frY6PzEVxjQavyjTv2CwdG3oiUcgfLmMPM1A8ET5uikSWMxZXa+\nD/mT9vUmig2msPjOcGRx/BksNAFqOoOt8o/+h8TSYjaPvBtrvHNnIvlkF2DdjB+1\nSaR6B64vmV2os+Zki7mdHzNp0fADMGdjTckvB96KccSSxAxHid0HO7mzvMZodj0N\niIWXl8OKOrJ8L9Wn/zlMy9S08ycnXCyal8gTQ4LKxCRcylGhkw821zqlWs910aIu\nQlBGe4gmZB57glR1fOfmiczKqoGMCu5HOtwyb/MNErO9IE/r1z4gLg49BXFjo7uT\nivH3VEYco3SBc1T50gopFOlEsZWwOGCG5mJ+3+vPcDVz/RwzarITEEJlcJpmI81a\njP6SwPwyO+hJNsLyRxtWGOIK4KTkquBgypQqvhoeV07zaX8B/O0CggEBAN/RvBZj\nNiHO0ff2+zxwCfxyveGgrOm5BVMJ5if4H0hHOdOM/ZU++ehfu3nNaotQUuFn4CT/\nx3/vb1l7ykfm8PQ9heffnRPOzsS9odltHlgGKoAStBernqCyuKdY2OZbfUMzMuPU\nYjHJNpq1KrWM6Vx4A1ujMZPVJpQhR2Fet90hTIgWqUYoFTbB5mCzDRHBi+sg8LYN\n/HKCvcRj3M/Sy68IefJ+eCiSYzJVm1FvANxixMceX7jwJiTRwsb1BOISZrlALSiU\n0xQnIB83KgVNswJbqEFexU2yEffQAPKcHhtFZmPWYg2JNX7cOlHEjaU0+/Og2ShO\nD2plt1tIgUGzDT8CggEBAMQykteVBs19Kbd+ycRbPCfGBMjngbctdI7tO2MM6sOT\n2oTqMNsU5Qr0N8SmdNeeoyKJgbrVX2ki2+LfGYrpr2pYF+Vs+7/xlrqIQjoOpG9E\nENWo+tFv0tuldsyZPv7B1QxO6ApNSQ8Q7Rcp3IxixcALofEghRc/55b3y0pcumBa\nlTHNZZjfU9HMHpicAxzCNnRHY8GVZr7lAzgGWNUZJnuCKTsIW3C8m33nIOHAcqYi\nmbY5uxIix11+31buREPY4ldnltG3AQm8bZ5RFMD+FWp0im4J5pQBPPTPM5+pFSyr\n3o1OQEDx1naJZDxxIZNjkMFXs7Q4lPOJM9uBRFNdjhMCggEAdQG7N2T0RqZNhDkc\nzFKyFcSSWaLa4nC0RN328Uw4Zlu98kdRxjUfBokNhDaMDXqXaXkZZ55D2DD+4CPj\n8sTbkIOdPkPbZSCHXbjZJMZzx4aprzyX44v3qIDmIa5D7eFEUd4xK4O7NdW/8w7k\n3fZlhM7EyqI859DVkzj2jQOsUTD4Rmi6Y4/Oz0p8um7AVVj+YZRd4n7bS63nsQSX\nyhmkG8PtpITTIjqtGwI/6UmDhLMptgK9/fulpTf3gHVU8S63fv763K04z99IXqlD\nEXS2MXRjOJFnyh9eX1PhOvO8lXspdOX9aqAhVEmjP13mwsg5MvsSq4xraK72NQVp\ndUQ6lwKCAQEArMKLouFK+C66SSWPrSNZOPyYwf94rT+NXz1uCa4aGtVamadOFdu0\nQ40AflzEjgjWRVcnsMiqFv3m+ULSTwuutsmTYSYyF7Y3r1DEYDL8gC1DVaBSG5GH\n7nkovshCPDmZzBi/IjMjneydmMP3vHZNAuo7UwP7rZlL6BeSHozAYI/ix9PBHneo\naxh96IuYAf7RzFoAcTmJG5a02uRb3GklBaR7gcu+GOs7UAXxYlf/nGLjTx5Op42Q\nV7ecGgP8gHG9/JDusQOgGl6dd8aVq8sQOIkeS/7T1oewkTDSmEheqNM+SNcapRVO\nb1pTtRU5J3uIv3bmek9IeZna2/Jbo7zBmQKCAQEAtSFU+Tftvrh////w2Ligjdr4\nMhgd3l7NWeed8eEIAd5KV6Ke1k6hxa4YxBgQx1kifn4gVyalAXEaqv9dAsQoeflE\nrj9StcNX/kWdX0RojbN/p4QVviJHCUIkF2GM4pKDAmBIfzuwmvFJgTzzS3qKwFdc\nZsnwzyaSPi54WZooy/HEm1nWACDSvIHylS00J4sMIdPg2/WLJ1v6ZlFlQPnCPkfQ\nVv7kNV7S9itcwOJkozfmsdRoFUS6kxqPsT3cANy1etSllrHdHtNNnmMXLFdBrqOL\nYvFZNUoc173n9yMAx02/W5qSc0c77ps9d5vRT4mBpBTqnOS4u5wOmFT680sPDw==\n-----END RSA PRIVATE KEY-----\n + expires_at: 2100-12-31 23:59:59 + created_at: 2020-12-31 23:59:59 + updated_at: 2020-12-31 23:59:59 diff --git a/backend/test/fixtures/oidc/primary_emails.yaml b/backend/test/fixtures/oidc/primary_emails.yaml new file mode 100644 index 000000000..41d516011 --- /dev/null +++ b/backend/test/fixtures/oidc/primary_emails.yaml @@ -0,0 +1,5 @@ +- id: 8eaaa61b-ad65-45ac-a5b8-d7c6d301d29e + user_id: b5dd5267-b462-48be-b70d-bcd6f1bbe7a5 + email_id: 51b7c175-ceb6-45ba-aae6-0092221c1b84 + created_at: 2020-12-31 23:59:59 + updated_at: 2020-12-31 23:59:59 diff --git a/backend/test/fixtures/oidc/users.yaml b/backend/test/fixtures/oidc/users.yaml new file mode 100644 index 000000000..590d37205 --- /dev/null +++ b/backend/test/fixtures/oidc/users.yaml @@ -0,0 +1,4 @@ +- id: b5dd5267-b462-48be-b70d-bcd6f1bbe7a5 + created_at: 2020-12-31 23:59:59 + updated_at: 2020-12-31 23:59:59 + diff --git a/backend/test/oidc_auth_requests_persister.go b/backend/test/oidc_auth_requests_persister.go new file mode 100644 index 000000000..daf451acc --- /dev/null +++ b/backend/test/oidc_auth_requests_persister.go @@ -0,0 +1,98 @@ +package test + +import ( + "context" + "github.com/gofrs/uuid" + "github.com/teamhanko/hanko/backend/persistence" + "github.com/teamhanko/hanko/backend/persistence/models" + "time" +) + +func NewOidcAuthRequestsPersister(init []models.AuthRequest, codes map[string]uuid.UUID) persistence.OIDCAuthRequestPersister { + return &oidcAuthRequestsPersister{ + oidcAuthRequests: append([]models.AuthRequest{}, init...), + oidcAuthCodes: codes, + } +} + +type oidcAuthRequestsPersister struct { + oidcAuthRequests []models.AuthRequest + oidcAuthCodes map[string]uuid.UUID +} + +func (o *oidcAuthRequestsPersister) Get(ctx context.Context, uuid uuid.UUID) (*models.AuthRequest, error) { + var found *models.AuthRequest + + for _, data := range o.oidcAuthRequests { + if data.ID == uuid { + d := data + found = &d + } + } + + return found, nil +} + +func (o *oidcAuthRequestsPersister) Create(ctx context.Context, authRequest models.AuthRequest) error { + o.oidcAuthRequests = append(o.oidcAuthRequests, authRequest) + + return nil +} + +func (o *oidcAuthRequestsPersister) Delete(ctx context.Context, uuid uuid.UUID) error { + index := -1 + + for i, data := range o.oidcAuthRequests { + if data.ID == uuid { + index = i + } + } + + if index > -1 { + o.oidcAuthRequests = append(o.oidcAuthRequests[:index], o.oidcAuthRequests[index+1:]...) + } + + for code, id := range o.oidcAuthCodes { + if id == uuid { + delete(o.oidcAuthCodes, code) + } + } + + return nil +} + +func (o *oidcAuthRequestsPersister) AuthorizeUser(ctx context.Context, uuid uuid.UUID, userID string) error { + for i, data := range o.oidcAuthRequests { + if data.ID == uuid { + o.oidcAuthRequests[i].UserID = userID + o.oidcAuthRequests[i].Done = true + o.oidcAuthRequests[i].AuthTime = time.Now() + } + } + + return nil +} + +func (o *oidcAuthRequestsPersister) StoreAuthCode(ctx context.Context, ID uuid.UUID, code string) error { + o.oidcAuthCodes[code] = ID + + return nil +} + +func (o *oidcAuthRequestsPersister) GetAuthRequestByCode(ctx context.Context, code string) (*models.AuthRequest, error) { + var found *models.AuthRequest + + uid, ok := o.oidcAuthCodes[code] + if !ok { + return nil, nil + } + + for _, data := range o.oidcAuthRequests { + if data.ID == uid { + d := data + found = &d + } + } + + return found, nil +} diff --git a/backend/test/oidc_keys_persister.go b/backend/test/oidc_keys_persister.go new file mode 100644 index 000000000..72b6f471a --- /dev/null +++ b/backend/test/oidc_keys_persister.go @@ -0,0 +1,40 @@ +package test + +import ( + "context" + "github.com/teamhanko/hanko/backend/persistence" + "github.com/teamhanko/hanko/backend/persistence/models" + "time" +) + +func NewOidcKeysPersister(init []models.Key) persistence.OIDCKeyPersister { + return &oidcKeysPersister{append([]models.Key{}, init...)} +} + +type oidcKeysPersister struct { + oidcKeys []models.Key +} + +func (o *oidcKeysPersister) GetSigningKey(ctx context.Context) (*models.SigningKey, error) { + var found *models.Key + + for _, data := range o.oidcKeys { + if data.ExpiresAt.After(time.Now()) { + if found == nil || found.ExpiresAt.After(data.ExpiresAt) { + found = &data + } + } + } + + return found.SigningKey(), nil +} + +func (o *oidcKeysPersister) GetPublicKeys(ctx context.Context) ([]models.PublicKey, error) { + var found []models.PublicKey + + for _, data := range o.oidcKeys { + found = append(found, data.PublicKey()) + } + + return found, nil +} diff --git a/backend/test/oidc_tokens_persister.go b/backend/test/oidc_tokens_persister.go new file mode 100644 index 000000000..d5f27f10e --- /dev/null +++ b/backend/test/oidc_tokens_persister.go @@ -0,0 +1,103 @@ +package test + +import ( + "context" + "github.com/gofrs/uuid" + "github.com/teamhanko/hanko/backend/persistence" + "github.com/teamhanko/hanko/backend/persistence/models" +) + +func NewOidcAccessTokensPersister(init []models.AccessToken) persistence.OIDCAccessTokenPersister { + return &oidcAccessTokensPersister{append([]models.AccessToken{}, init...)} +} + +type oidcAccessTokensPersister struct { + oidcAccessTokens []models.AccessToken +} + +func (o *oidcAccessTokensPersister) Get(ctx context.Context, uuid uuid.UUID) (*models.AccessToken, error) { + var found *models.AccessToken + + for _, data := range o.oidcAccessTokens { + if data.ID == uuid { + d := data + found = &d + } + } + + return found, nil +} + +func (o *oidcAccessTokensPersister) Create(ctx context.Context, accessToken models.AccessToken) error { + o.oidcAccessTokens = append(o.oidcAccessTokens, accessToken) + + return nil +} + +func (o *oidcAccessTokensPersister) Delete(ctx context.Context, accessToken models.AccessToken) error { + index := -1 + + for i, data := range o.oidcAccessTokens { + if data.ID == accessToken.ID { + index = i + } + } + + if index > -1 { + o.oidcAccessTokens = append(o.oidcAccessTokens[:index], o.oidcAccessTokens[index+1:]...) + } + + return nil +} + +func NewOidcRefreshTokensPersister(init []models.RefreshToken) persistence.OIDCRefreshTokenPersister { + return &oidcRefreshTokensPersister{append([]models.RefreshToken{}, init...)} +} + +type oidcRefreshTokensPersister struct { + oidcRefreshTokens []models.RefreshToken +} + +func (o *oidcRefreshTokensPersister) Get(ctx context.Context, uuid uuid.UUID) (*models.RefreshToken, error) { + var found *models.RefreshToken + + for _, data := range o.oidcRefreshTokens { + if data.ID == uuid { + d := data + found = &d + } + } + + return found, nil +} + +func (o *oidcRefreshTokensPersister) Create(ctx context.Context, refreshToken models.RefreshToken) error { + o.oidcRefreshTokens = append(o.oidcRefreshTokens, refreshToken) + + return nil +} + +func (o *oidcRefreshTokensPersister) Delete(ctx context.Context, refreshToken models.RefreshToken) error { + index := -1 + for i, data := range o.oidcRefreshTokens { + if data.ID == refreshToken.ID { + index = i + } + } + + if index > -1 { + o.oidcRefreshTokens = append(o.oidcRefreshTokens[:index], o.oidcRefreshTokens[index+1:]...) + } + + return nil +} + +func (o *oidcRefreshTokensPersister) TerminateSessions(ctx context.Context, clientID string, userID string) error { + for _, data := range o.oidcRefreshTokens { + if data.ClientID == clientID && data.UserID == userID { + _ = o.Delete(ctx, data) + } + } + + return nil +} diff --git a/backend/test/persister.go b/backend/test/persister.go index 4731170b5..89046f255 100644 --- a/backend/test/persister.go +++ b/backend/test/persister.go @@ -2,11 +2,12 @@ package test import ( "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" "github.com/teamhanko/hanko/backend/persistence" "github.com/teamhanko/hanko/backend/persistence/models" ) -func NewPersister(user []models.User, passcodes []models.Passcode, jwks []models.Jwk, credentials []models.WebauthnCredential, sessionData []models.WebauthnSessionData, passwords []models.PasswordCredential, auditLogs []models.AuditLog, emails []models.Email, primaryEmails []models.PrimaryEmail, identities []models.Identity, tokens []models.Token) persistence.Persister { +func NewPersister(user []models.User, passcodes []models.Passcode, jwks []models.Jwk, credentials []models.WebauthnCredential, sessionData []models.WebauthnSessionData, passwords []models.PasswordCredential, auditLogs []models.AuditLog, emails []models.Email, primaryEmails []models.PrimaryEmail, identities []models.Identity, tokens []models.Token, accessTokens []models.AccessToken, refreshTokens []models.RefreshToken, keys []models.Key, authRequests []models.AuthRequest, codes map[string]uuid.UUID) persistence.Persister { return &persister{ userPersister: NewUserPersister(user), passcodePersister: NewPasscodePersister(passcodes), @@ -19,6 +20,10 @@ func NewPersister(user []models.User, passcodes []models.Passcode, jwks []models primaryEmailPersister: NewPrimaryEmailPersister(primaryEmails), identityPersister: NewIdentityPersister(identities), tokenPersister: NewTokenPersister(tokens), + oidcAccessTokensPersister: NewOidcAccessTokensPersister(accessTokens), + oidcRefreshTokensPersister: NewOidcRefreshTokensPersister(refreshTokens), + oidcKeysPersister: NewOidcKeysPersister(keys), + oidcAuthRequestsPersister: NewOidcAuthRequestsPersister(authRequests, codes), } } @@ -34,6 +39,10 @@ type persister struct { primaryEmailPersister persistence.PrimaryEmailPersister identityPersister persistence.IdentityPersister tokenPersister persistence.TokenPersister + oidcAccessTokensPersister persistence.OIDCAccessTokenPersister + oidcRefreshTokensPersister persistence.OIDCRefreshTokenPersister + oidcKeysPersister persistence.OIDCKeyPersister + oidcAuthRequestsPersister persistence.OIDCAuthRequestPersister } func (p *persister) GetPasswordCredentialPersister() persistence.PasswordCredentialPersister { @@ -132,3 +141,35 @@ func (p *persister) GetTokenPersister() persistence.TokenPersister { func (p *persister) GetTokenPersisterWithConnection(tx *pop.Connection) persistence.TokenPersister { return p.tokenPersister } + +func (p *persister) GetOIDCAccessTokenPersister() persistence.OIDCAccessTokenPersister { + return p.oidcAccessTokensPersister +} + +func (p *persister) GetOIDCAccessTokenPersisterWithConnection(tx *pop.Connection) persistence.OIDCAccessTokenPersister { + return p.oidcAccessTokensPersister +} + +func (p *persister) GetOIDCRefreshTokenPersister() persistence.OIDCRefreshTokenPersister { + return p.oidcRefreshTokensPersister +} + +func (p *persister) GetOIDCRefreshTokenPersisterWithConnection(tx *pop.Connection) persistence.OIDCRefreshTokenPersister { + return p.oidcRefreshTokensPersister +} + +func (p *persister) GetOIDCKeyPersister() persistence.OIDCKeyPersister { + return p.oidcKeysPersister +} + +func (p *persister) GetOIDCKeyPersisterWithConnection(tx *pop.Connection) persistence.OIDCKeyPersister { + return p.oidcKeysPersister +} + +func (p *persister) GetOIDCAuthRequestPersister() persistence.OIDCAuthRequestPersister { + return p.oidcAuthRequestsPersister +} + +func (p *persister) GetOIDCAuthRequestPersisterWithConnection(tx *pop.Connection) persistence.OIDCAuthRequestPersister { + return p.oidcAuthRequestsPersister +}