From 997b5d3e5573b3274129709d6b6acca8ae22c05d Mon Sep 17 00:00:00 2001 From: Florian Loch Date: Mon, 4 Mar 2024 23:01:54 +0100 Subject: [PATCH 1/2] refactor: remove unused constant --- csrf.go | 1 - 1 file changed, 1 deletion(-) diff --git a/csrf.go b/csrf.go index 97a3925..56a05eb 100644 --- a/csrf.go +++ b/csrf.go @@ -19,7 +19,6 @@ const ( errorKey string = "gorilla.csrf.Error" skipCheckKey string = "gorilla.csrf.Skip" cookieName string = "_gorilla_csrf" - errorPrefix string = "gorilla/csrf: " ) var ( From b6bb886ee7a9334e66c05e356fb32937f6c24c03 Mon Sep 17 00:00:00 2001 From: Florian Loch Date: Mon, 4 Mar 2024 23:10:18 +0100 Subject: [PATCH 2/2] feat: allow to get the token straight from context.Context without having to provide a http.Request --- context.go | 6 +++++- helpers.go | 10 +++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/context.go b/context.go index d24b146..f3abab5 100644 --- a/context.go +++ b/context.go @@ -10,7 +10,11 @@ import ( ) func contextGet(r *http.Request, key string) (interface{}, error) { - val := r.Context().Value(key) + return valueFromContext(r.Context(), key) +} + +func valueFromContext(ctx context.Context, key string) (interface{}, error) { + val := ctx.Value(key) if val == nil { return nil, fmt.Errorf("no value exists in the context for key %q", key) } diff --git a/helpers.go b/helpers.go index 99005ee..f0495fe 100644 --- a/helpers.go +++ b/helpers.go @@ -1,6 +1,7 @@ package csrf import ( + "context" "crypto/rand" "crypto/subtle" "encoding/base64" @@ -14,7 +15,14 @@ import ( // a JSON response body. An empty token will be returned if the middleware // has not been applied (which will fail subsequent validation). func Token(r *http.Request) string { - if val, err := contextGet(r, tokenKey); err == nil { + return TokenFromContext(r.Context()) +} + +// TokenFromContext returns a masked CSRF token ready for passing into HTML template or +// a JSON response body. An empty token will be returned if the middleware +// has not been applied (which will fail subsequent validation). +func TokenFromContext(ctx context.Context) string { + if val, err := valueFromContext(ctx, tokenKey); err == nil { if maskedToken, ok := val.(string); ok { return maskedToken }