Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow multiple JWT tokens to be configured (closes #108) #109

Merged
merged 14 commits into from
Jan 11, 2024
Merged
32 changes: 26 additions & 6 deletions pkg/sidecar/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,22 @@ type Config struct {
RenewSignalDeprecated string `hcl:"renewSignal"`

// JWT configuration
JWTAudience string `hcl:"jwt_audience"`
JWTSvidFilename string `hcl:"jwt_svid_file_name"`
JWTBundleFilename string `hcl:"jwt_bundle_file_name"`
JwtSvids []JwtConfig `hcl:"jwt_svids"`
JWTAudienceDeprecated string `hcl:"jwt_audience"`
JWTSvidFilenameDeprecated string `hcl:"jwt_svid_file_name"`
keeganwitt marked this conversation as resolved.
Show resolved Hide resolved
JWTBundleFilename string `hcl:"jwt_bundle_file_name"`

// TODO: is there a reason for this to be exposed? and inside of config?
ReloadExternalProcess func() error
// TODO: is there a reason for this to be exposed? and inside of config?
Log logrus.FieldLogger
}

type JwtConfig struct {
JWTAudience string `hcl:"jwt_audience"`
JWTSvidFilename string `hcl:"jwt_svid_file_name"`
}

// ParseConfig parses the given HCL file into a SidecarConfig struct
func ParseConfig(file string) (*Config, error) {
sidecarConfig := new(Config)
Expand Down Expand Up @@ -120,11 +126,17 @@ func ValidateConfig(c *Config) error {
c.RenewSignal = c.RenewSignalDeprecated
}

for _, jwtConfig := range c.JwtSvids {
if countEmpty(jwtConfig.JWTSvidFilename, jwtConfig.JWTAudience) > 0 {
return errors.New("both 'jwt_file_name' and 'jwt_audience' are required in 'jwt_svids'")
keeganwitt marked this conversation as resolved.
Show resolved Hide resolved
}
}

x509EmptyCount := countEmpty(c.SvidFileName, c.SvidBundleFileName, c.SvidKeyFileName)
jwtSVIDEmptyCount := countEmpty(c.JWTSvidFilename, c.JWTAudience)
jwtSVIDEmptyCount := countEmpty(c.JWTSvidFilenameDeprecated, c.JWTAudienceDeprecated)
jwtBundleEmptyCount := countEmpty(c.SvidBundleFileName)
if x509EmptyCount == 3 && jwtSVIDEmptyCount == 2 && jwtBundleEmptyCount == 1 {
return errors.New("at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), ('jwt_file_name', 'jwt_audience'), or ('jwt_bundle_file_name') must be fully specified")
if x509EmptyCount == 3 && jwtSVIDEmptyCount == 2 && c.JwtSvids == nil && jwtBundleEmptyCount == 1 {
return errors.New("at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), ('jwt_file_name', 'jwt_audience'), 'jwt_svids', or ('jwt_bundle_file_name') must be fully specified")
keeganwitt marked this conversation as resolved.
Show resolved Hide resolved
}

if x509EmptyCount != 0 && x509EmptyCount != 3 {
Expand All @@ -135,6 +147,14 @@ func ValidateConfig(c *Config) error {
return errors.New("all or none of 'jwt_file_name', 'jwt_audience' must be specified")
}

if jwtSVIDEmptyCount == 0 {
c.Log.Warn(getWarning("jwt_file_name and jwt_audience", "jwt_svids"))
}

if jwtSVIDEmptyCount != 0 && c.JwtSvids == nil {
return errors.New("must not specify deprecated JWT configs ('jwt_file_name' and 'jwt_audience') and new JWT config ('jwt_svids')")
}

return nil
}

Expand Down
33 changes: 25 additions & 8 deletions pkg/sidecar/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ func TestParseConfig(t *testing.T) {
assert.Equal(t, expectedSvidFileName, c.SvidFileName)
assert.Equal(t, expectedKeyFileName, c.SvidKeyFileName)
assert.Equal(t, expectedSvidBundleFileName, c.SvidBundleFileName)
assert.Equal(t, expectedJWTSVIDFileName, c.JWTSvidFilename)
assert.Equal(t, expectedJWTSVIDFileName, c.JWTSvidFilenameDeprecated)
assert.Equal(t, expectedJWTBundleFileName, c.JWTBundleFilename)
assert.Equal(t, expectedJWTAudience, c.JWTAudience)
assert.Equal(t, expectedJWTAudience, c.JWTAudienceDeprecated)
assert.True(t, c.AddIntermediatesToBundle)
}

Expand All @@ -56,12 +56,29 @@ func TestValidateConfig(t *testing.T) {
SvidBundleFileName: "bundle.pem",
},
},
{
name: "warns on deprecated jwt configs",
config: &Config{
AgentAddress: "path",
JWTAudienceDeprecated: "your-audience",
JWTSvidFilenameDeprecated: "jwt.token",
JWTBundleFilename: "bundle.json",
},
expectLogs: []shortEntry{
{
Level: logrus.WarnLevel,
Message: "jwt_file_name and jwt_audience will be deprecated, should be used as jwt_svids",
},
},
},
{
name: "no error",
config: &Config{
AgentAddress: "path",
JWTAudience: "your-audience",
JWTSvidFilename: "jwt.token",
AgentAddress: "path",
JwtSvids: []JwtConfig{{
JWTSvidFilename: "jwt.token",
JWTAudience: "your-audience",
}},
JWTBundleFilename: "bundle.json",
},
},
Expand All @@ -70,7 +87,7 @@ func TestValidateConfig(t *testing.T) {
config: &Config{
AgentAddress: "path",
},
expectError: "at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), ('jwt_file_name', 'jwt_audience'), or ('jwt_bundle_file_name') must be fully specified",
expectError: "at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), ('jwt_file_name', 'jwt_audience'), 'jwt_svids', or ('jwt_bundle_file_name') must be fully specified",
},
{
name: "missing svid config",
Expand All @@ -83,8 +100,8 @@ func TestValidateConfig(t *testing.T) {
{
name: "missing jwt config",
config: &Config{
AgentAddress: "path",
JWTSvidFilename: "cert.pem",
AgentAddress: "path",
JWTSvidFilenameDeprecated: "cert.pem",
},
expectError: "all or none of 'jwt_file_name', 'jwt_audience' must be specified",
},
Expand Down
41 changes: 26 additions & 15 deletions pkg/sidecar/sidecar.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,30 @@ func (s *Sidecar) RunDaemon(ctx context.Context) error {
}()
}

if s.config.JWTSvidFilename != "" && s.config.JWTAudience != "" {
if s.config.JWTSvidFilenameDeprecated != "" && s.config.JWTAudienceDeprecated != "" {
jwtSource, err := workloadapi.NewJWTSource(ctx, workloadapi.WithClientOptions(s.getWorkloadAPIAdress()))
if err != nil {
s.config.Log.Fatalf("Error watching JWT svid updates: %v", err)
}
s.jwtSource = jwtSource
defer s.jwtSource.Close()

wg.Add(1)
go func() {
defer wg.Done()
s.updateJWTSVID(ctx)
}()
if s.config.JwtSvids != nil {
for _, jwtConfig := range s.config.JwtSvids {
jwtConfig := jwtConfig
wg.Add(1)
go func() {
defer wg.Done()
s.updateJWTSVID(ctx, jwtConfig.JWTAudience, jwtConfig.JWTSvidFilename)
}()
}
} else {
wg.Add(1)
go func() {
defer wg.Done()
s.updateJWTSVID(ctx, s.config.JWTAudienceDeprecated, s.config.JWTSvidFilenameDeprecated)
}()
}
}

wg.Wait()
Expand Down Expand Up @@ -274,14 +285,14 @@ func (s *Sidecar) updateJWTBundle(jwkSet *jwtbundle.Set) {
}
}

func (s *Sidecar) fetchJWTSVID(ctx context.Context) (*jwtsvid.SVID, error) {
jwtSVID, err := s.jwtSource.FetchJWTSVID(ctx, jwtsvid.Params{Audience: s.config.JWTAudience})
func (s *Sidecar) fetchJWTSVIDs(ctx context.Context, jwtAudience string) (*jwtsvid.SVID, error) {
jwtSVID, err := s.jwtSource.FetchJWTSVID(ctx, jwtsvid.Params{Audience: jwtAudience})
if err != nil {
s.config.Log.Errorf("Unable to fetch JWT SVID: %v", err)
return nil, err
}

_, err = jwtsvid.ParseAndValidate(jwtSVID.Marshal(), s.jwtSource, []string{s.config.JWTAudience})
_, err = jwtsvid.ParseAndValidate(jwtSVID.Marshal(), s.jwtSource, []string{jwtAudience})
if err != nil {
s.config.Log.Errorf("Unable to parse or validate token: %v", err)
return nil, err
Expand Down Expand Up @@ -312,16 +323,16 @@ func getRefreshInterval(svid *jwtsvid.SVID) time.Duration {
return time.Until(svid.Expiry)/2 + time.Second
}

func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context) (*jwtsvid.SVID, error) {
func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context, jwtAudience string, jwtSvidFilename string) (*jwtsvid.SVID, error) {
s.config.Log.Debug("Updating JWT SVID")

jwtSVID, err := s.fetchJWTSVID(ctx)
jwtSVID, err := s.fetchJWTSVIDs(ctx, jwtAudience)
if err != nil {
s.config.Log.Errorf("Unable to update JWT SVID: %v", err)
return nil, err
}

filePath := path.Join(s.config.CertDir, s.config.JWTSvidFilename)
filePath := path.Join(s.config.CertDir, jwtSvidFilename)
if err = os.WriteFile(filePath, []byte(jwtSVID.Marshal()), os.ModePerm); err != nil {
s.config.Log.Errorf("Unable to update JWT SVID: %v", err)
return nil, err
Expand All @@ -331,10 +342,10 @@ func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context) (*jwtsvid.SVID, erro
return jwtSVID, nil
}

func (s *Sidecar) updateJWTSVID(ctx context.Context) {
func (s *Sidecar) updateJWTSVID(ctx context.Context, jwtAudience string, jwtSvidFilename string) {
retryInterval := createRetryIntervalFunc()
var initialInterval time.Duration
jwtSVID, err := s.performJWTSVIDUpdate(ctx)
jwtSVID, err := s.performJWTSVIDUpdate(ctx, jwtAudience, jwtSvidFilename)
if err != nil {
// If the first update fails, use the retry interval
initialInterval = retryInterval()
Expand All @@ -350,7 +361,7 @@ func (s *Sidecar) updateJWTSVID(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
jwtSVID, err = s.performJWTSVIDUpdate(ctx)
jwtSVID, err = s.performJWTSVIDUpdate(ctx, jwtAudience, jwtSvidFilename)
if err == nil {
retryInterval = createRetryIntervalFunc()
ticker.Reset(getRefreshInterval(jwtSVID))
Expand Down
Loading