Skip to content

Commit

Permalink
feat: added simple access token caching (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
pacificcode authored Nov 12, 2024
1 parent 508b3c1 commit c9856bc
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: "\U0001F389 New Product Feature"
body: added simple caching for access tokens
time: 2024-11-11T10:19:38.406242-08:00
52 changes: 49 additions & 3 deletions vault/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import (
"fmt"
"io"
"log"
"math"
"net/http"
"os"
"strings"
"time"

Expand Down Expand Up @@ -58,6 +60,12 @@ type Vault struct {
Configuration
}

//nolint:tagliatelle // the json is coming from an external API call
type TokenCache struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
}

// New returns a Vault or an error if the Configuration is invalid
func New(config Configuration) (*Vault, error) {
if config.Provider == auth.CLIENT {
Expand Down Expand Up @@ -138,12 +146,48 @@ type accessTokenRequest struct {
AwsHeaders string `json:"aws_headers,omitempty"`
}

//nolint:tagliatelle // the json is coming from an external API call
type accessTokenResponse struct {
AccessToken string `json:"accessToken"`
ExpiresIn int `json:"expiresIn"`
}

func (v Vault) setCacheAccessToken(value string, expiresIn int) bool {
percentage := 0.9
cache := TokenCache{}
cache.AccessToken = value
cache.ExpiresIn = (int(time.Now().Unix()) + expiresIn) - int(math.Floor(float64(expiresIn)*percentage))

data, err := json.Marshal(cache)
if err != nil {
return false
}
os.Setenv("SS_AT", string(data))
return true
}

func (v Vault) getCacheAccessToken() (string, bool) {
data, ok := os.LookupEnv("SS_AT")
if !ok {
os.Setenv("SS_AT", "")
return "", ok
}
cache := TokenCache{}
if err := json.Unmarshal([]byte(data), &cache); err != nil {
return "", false
}
if time.Now().Unix() < int64(cache.ExpiresIn) {
return cache.AccessToken, true
}
return "", false
}

// getAccessToken returns access token fetched from DSV.
func (v Vault) getAccessToken() (string, error) {
accessToken, found := v.getCacheAccessToken()
if found {
return accessToken, nil
}
var rBody accessTokenRequest
switch v.Provider {
case auth.AWS:
Expand All @@ -168,7 +212,6 @@ func (v Vault) getAccessToken() (string, error) {

request, err := json.Marshal(&rBody)
if err != nil {
return "", fmt.Errorf("marshalling token request body: %w", err)
}

url := v.urlFor("token", "")
Expand All @@ -181,9 +224,12 @@ func (v Vault) getAccessToken() (string, error) {
// TODO: cache the token until it expires.
resp := &accessTokenResponse{}
if err = json.Unmarshal(response, &resp); err != nil {
return "", fmt.Errorf("unmarshalling token response: %w", err)
return "", fmt.Errorf("unmarshaling token response: %w", err)
}
ok := v.setCacheAccessToken(resp.AccessToken, resp.ExpiresIn)
if !ok {
return "", fmt.Errorf("unable to cache access token")
}

return resp.AccessToken, nil
}

Expand Down

0 comments on commit c9856bc

Please sign in to comment.