Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PSK support for v2
Browse files Browse the repository at this point in the history
nbrownus committed Nov 13, 2024
1 parent 5380fef commit b07d245
Showing 9 changed files with 459 additions and 33 deletions.
15 changes: 7 additions & 8 deletions connection_state.go
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ type ConnectionState struct {
writeLock sync.Mutex
}

func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern, psk []byte) (*ConnectionState, error) {
var dhFunc noise.DHFunc
switch crt.Curve() {
case cert.Curve_CURVE25519:
@@ -56,13 +56,12 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
b.Update(l, 0)

hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: ncs,
Random: rand.Reader,
Pattern: pattern,
Initiator: initiator,
StaticKeypair: static,
//NOTE: These should come from CertState (pki.go) when we finally implement it
PresharedKey: []byte{},
CipherSuite: ncs,
Random: rand.Reader,
Pattern: pattern,
Initiator: initiator,
StaticKeypair: static,
PresharedKey: psk,
PresharedKeyPlacement: 0,
})
if err != nil {
132 changes: 132 additions & 0 deletions e2e/handshakes_test.go
Original file line number Diff line number Diff line change
@@ -1105,6 +1105,138 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
theirControl.Stop()
}

func TestPSK(t *testing.T) {
tests := []struct {
name string
myPskMode nebula.PskMode
theirPskMode nebula.PskMode
}{
// All accepting
{
name: "both accepting",
myPskMode: nebula.PskAccepting,
theirPskMode: nebula.PskAccepting,
},

// accepting and sending both ways
{
name: "accepting to sending",
myPskMode: nebula.PskAccepting,
theirPskMode: nebula.PskSending,
},
{
name: "sending to accepting",
myPskMode: nebula.PskSending,
theirPskMode: nebula.PskAccepting,
},

// All sending
{
name: "sending to sending",
myPskMode: nebula.PskSending,
theirPskMode: nebula.PskSending,
},

// enforced and sending both ways
{
name: "enforced to sending",
myPskMode: nebula.PskEnforced,
theirPskMode: nebula.PskSending,
},
{
name: "sending to enforced",
myPskMode: nebula.PskSending,
theirPskMode: nebula.PskEnforced,
},

// All enforced
{
name: "both enforced",
myPskMode: nebula.PskEnforced,
theirPskMode: nebula.PskEnforced,
},

// Enforced can technically handshake with an accepting node, but it is bad to be in this state
{
name: "enforced to accepting",
myPskMode: nebula.PskEnforced,
theirPskMode: nebula.PskAccepting,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var myPskSettings, theirPskSettings m

switch test.myPskMode {
case nebula.PskAccepting:
myPskSettings = m{"psk": &m{"mode": "accepting", "keys": []string{"garbage0", "this is a key"}}}
case nebula.PskSending:
myPskSettings = m{"psk": &m{"mode": "sending", "keys": []string{"this is a key", "garbage1"}}}
case nebula.PskEnforced:
myPskSettings = m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key", "garbage2"}}}
}

switch test.theirPskMode {
case nebula.PskAccepting:
theirPskSettings = m{"psk": &m{"mode": "accepting", "keys": []string{"garbage3", "this is a key"}}}
case nebula.PskSending:
theirPskSettings = m{"psk": &m{"mode": "sending", "keys": []string{"this is a key", "garbage4"}}}
case nebula.PskEnforced:
theirPskSettings = m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key", "garbage5"}}}
}

ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
myControl, myVpnIp, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.0.0.1/24", myPskSettings)
theirControl, theirVpnIp, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.0.0.2/24", theirPskSettings)

myControl.InjectLightHouseAddr(theirVpnIp[0].Addr(), theirUdpAddr)
r := router.NewR(t, myControl, theirControl)

// Start the servers
myControl.Start()
theirControl.Start()

t.Log("Route until we see our cached packet flow")
myControl.InjectTunUDPPacket(theirVpnIp[0].Addr(), 80, myVpnIp[0].Addr(), 80, []byte("Hi from me"))
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
h := &header.H{}
err := h.Parse(p.Data)
if err != nil {
panic(err)
}

// If this is the stage 1 handshake packet and I am configured to send with a psk, my cert name should
// not appear. It would likely be more obvious to unmarshal the payload and check but this works fine for now
if test.myPskMode == nebula.PskEnforced || test.myPskMode == nebula.PskSending {
if h.Type == 0 && h.MessageCounter == 1 {
assert.NotContains(t, string(p.Data), "test me")
}
}

if p.To == theirUdpAddr && h.Type == 1 {
return router.RouteAndExit
}

return router.KeepRouting
})

t.Log("My cached packet should be received by them")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp[0].Addr(), theirVpnIp[0].Addr(), 80, 80)

t.Log("Test the tunnel with them")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
assertTunnel(t, myVpnIp[0].Addr(), theirVpnIp[0].Addr(), myControl, theirControl, r)

myControl.Stop()
theirControl.Stop()
//TODO: assert hostmaps
})
}

}

//TODO: test
// Race winner renews and handshakes
// Race loser renews and handshakes
7 changes: 3 additions & 4 deletions e2e/router/router.go
Original file line number Diff line number Diff line change
@@ -111,10 +111,6 @@ type ExitFunc func(packet *udp.Packet, receiver *nebula.Control) ExitType
func NewR(t testing.TB, controls ...*nebula.Control) *R {
ctx, cancel := context.WithCancel(context.Background())

if err := os.MkdirAll("mermaid", 0755); err != nil {
panic(err)
}

r := &R{
controls: make(map[netip.AddrPort]*nebula.Control),
vpnControls: make(map[netip.Addr]*nebula.Control),
@@ -194,6 +190,9 @@ func (r *R) renderFlow() {
return
}

if err := os.MkdirAll(filepath.Dir(r.fn), 0755); err != nil {
panic(err)
}
f, err := os.OpenFile(r.fn, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0644)
if err != nil {
panic(err)
33 changes: 32 additions & 1 deletion examples/config.yml
Original file line number Diff line number Diff line change
@@ -19,6 +19,38 @@ pki:
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
# default_version: 1

# psk can be used to mask the contents of handshakes.
psk:
# `mode` defines how the pre shared keys can be used in a handshake.
# `accepting` (the default) will initiate handshakes using an empty key and will try to use any keys provided when
# receiving handshakes, including an empty key.
# `sending` will initiate handshakes with the first key provided and will try to use any keys provided when
# receiving handshakes, including an empty key.
# `enforced` will initiate handshakes with the first psk key provided and will try to use any keys provided when
# responding to handshakes. An empty key will not be allowed.
#
# To change a mesh from not using a psk to enforcing psk:
# 1. Leave `mode` as `accepting` and configure `psk.keys` to match on all nodes in the mesh and reload.
# 2. Change `mode` to `sending` on all nodes in the mesh and reload.
# 3. Change `mode` to `enforced` on all nodes in the mesh and reload.
#mode: accepting

# The keys provided are sent through hkdf to ensure the shared secret used in the noise protocol is the
# correct byte length.
#
# Only the first key is used for outbound handshakes but all keys provided will be tried in the order specified, on
# incoming handshakes. This is to allow for psk rotation.
#
# To rotate a primary key:
# 1. Put the new key in the 2nd slot on every node in the mesh and reload.
# 2. Move the key from the 2nd slot to the 1st slot, the old primary key is now in the 2nd slot, reload.
# 3. Remove the old primary key once it is no longer in use on every node in the mesh and reload.
#keys:
# - shared secret string, this one is used in all outbound handshakes # This is the primary key used when sending handshakes
# - this is a fallback key, received handshakes can use this
# - another fallback, received handshakes can use this one too
# - "\x68\x65\x6c\x6c\x6f\x20\x66\x72\x69\x65\x6e\x64\x73" # for raw bytes if you desire

# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
# The syntax is:
@@ -309,7 +341,6 @@ logging:
# after receiving the response for lighthouse queries
#trigger_buffer: 64


# Nebula security group configuration
firewall:
# Action to take when a packet is not allowed by the firewall rules.
59 changes: 39 additions & 20 deletions handshake_ix.go
Original file line number Diff line number Diff line change
@@ -50,7 +50,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
Error("Unable to handshake with host because no certificate handshake bytes is available")
}

ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX, cs.psk.primary)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
@@ -104,34 +104,53 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Error("Unable to handshake with host because no certificate is available")
}

ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state")
return
}
var (
err error
ci *ConnectionState
msg []byte
)

// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)
hs := &NebulaHandshake{}

msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to call noise.ReadMessage")
return
for _, psk := range cs.psk.keys {
ci, err = NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX, psk)
if err != nil {
//TODO: should be bother logging this, if we have multiple psks and the error is unrelated it will be verbose.
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state")
continue
}

msg, _, _, err = ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
// Calls to ReadMessage with an incorrect psk should fail, try the next one if we have one
continue
}

// Sometimes ReadMessage returns fine with a nil psk even if the handshake is using a psk, ensure our protobuf
// comes out clean as well
err = hs.Unmarshal(msg)
if err == nil {
// There was no error, we can continue with this handshake
break
}

// The unmarshal failed, try the next psk if we have one
}

hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
// We finished with an error, log it and get out
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("udpAddr", addr).
// We aren't logging the error here because we can't be sure of the failure when using psk
f.l.WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed unmarshal handshake message")
Error("Was unable to decrypt the handshake")
return
}

// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)

remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
if err != nil {
e := f.l.WithError(err).WithField("udpAddr", addr).
5 changes: 5 additions & 0 deletions handshake_manager_test.go
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@ import (
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_NewHandshakeManagerVpnIp(t *testing.T) {
@@ -23,11 +24,15 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {

lh := newTestLighthouse()

psk, err := NewPsk(PskAccepting, nil)
require.NoError(t, err)

cs := &CertState{
defaultVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
psk: psk,
}

blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
20 changes: 20 additions & 0 deletions pki.go
Original file line number Diff line number Diff line change
@@ -38,6 +38,8 @@ type CertState struct {
pkcs11Backed bool
cipher string

psk *Psk

myVpnNetworks []netip.Prefix
myVpnNetworksTable *bart.Table[struct{}]
myVpnAddrs []netip.Addr
@@ -97,6 +99,14 @@ func (p *PKI) reload(c *config.C, initial bool) error {
err.Log(p.l)
}

err = p.reloadCAPool(c)
if err != nil {
if initial {
return err
}
err.Log(p.l)
}

return nil
}

@@ -181,6 +191,16 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
}
}

psk, err := NewPskFromConfig(c)
if err != nil {
return util.NewContextualError("Failed to load psk from config", nil, err)
}
if len(psk.keys) > 0 {
p.l.WithField("pskMode", psk.mode).WithField("keysLen", len(psk.keys)).
Info("pre shared keys are in use")
}
newState.psk = psk

p.cs.Store(newState)

//TODO: newState needs a stringer that does json
150 changes: 150 additions & 0 deletions psk.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package nebula

import (
"crypto/sha256"
"errors"
"fmt"
"io"

"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"golang.org/x/crypto/hkdf"
)

var ErrNotAPskMode = errors.New("not a psk mode")
var ErrKeyTooShort = errors.New("key is too short")
var ErrNotEnoughPskKeys = errors.New("at least 1 key is required")

// MinPskLength is the minimum bytes that we accept for a user defined psk, the choice is arbitrary
const MinPskLength = 8

type PskMode int

const (
PskAccepting PskMode = 0
PskSending PskMode = 1
PskEnforced PskMode = 2
)

func NewPskMode(m string) (PskMode, error) {
switch m {
case "accepting":
return PskAccepting, nil
case "sending":
return PskSending, nil
case "enforced":
return PskEnforced, nil
}
return PskAccepting, ErrNotAPskMode
}

func (p PskMode) String() string {
switch p {
case PskAccepting:
return "accepting"
case PskSending:
return "sending"
case PskEnforced:
return "enforced"
}

return "unknown"
}

func (p PskMode) IsValid() bool {
switch p {
case PskAccepting, PskSending, PskEnforced:
return true
default:
return false
}
}

type Psk struct {
// pskMode sets how psk works, ignored, allowed for incoming, or enforced for all
mode PskMode

// primary is the key to use when sending, it may be nil
primary []byte

// keys holds all pre-computed psk hkdfs
// Handshakes iterate this directly
keys [][]byte
}

// NewPskFromConfig is a helper for initial boot and config reloading.
func NewPskFromConfig(c *config.C) (*Psk, error) {
sMode := c.GetString("psk.mode", "accepting")
mode, err := NewPskMode(sMode)
if err != nil {
return nil, util.NewContextualError("Could not parse psk.mode", m{"mode": mode}, err)
}

return NewPsk(
mode,
c.GetStringSlice("psk.keys", nil),
)
}

// NewPsk creates a new Psk object and handles the caching of all accepted keys
func NewPsk(mode PskMode, keys []string) (*Psk, error) {
if !mode.IsValid() {
return nil, ErrNotAPskMode
}

psk := &Psk{
mode: mode,
}

err := psk.cachePsks(keys)
if err != nil {
return nil, err
}

return psk, nil
}

// cachePsks generates all psks we accept and caches them to speed up handshaking
func (p *Psk) cachePsks(keys []string) error {
if p.mode != PskAccepting && len(keys) < 1 {
return ErrNotEnoughPskKeys
}

p.keys = [][]byte{}

for i, rk := range keys {
k, err := sha256KdfFromString(rk)
if err != nil {
return fmt.Errorf("failed to generate key for position %v: %w", i, err)
}

p.keys = append(p.keys, k)
}

if p.mode != PskAccepting {
// We are either sending or enforcing, the primary key must the first slot
p.primary = p.keys[0]
}

if p.mode != PskEnforced {
// If we are not enforcing psk use then a nil psk is allowed
p.keys = append(p.keys, nil)
}

return nil
}

// sha256KdfFromString generates a useful key to use from a provided secret
func sha256KdfFromString(secret string) ([]byte, error) {
if len(secret) < MinPskLength {
return nil, ErrKeyTooShort
}

hmacKey := make([]byte, sha256.Size)
_, err := io.ReadFull(hkdf.New(sha256.New, []byte(secret), nil, nil), hmacKey)
if err != nil {
return nil, err
}

return hmacKey, nil
}
71 changes: 71 additions & 0 deletions psk_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package nebula

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestNewPsk(t *testing.T) {
t.Run("mode accepting", func(t *testing.T) {
p, err := NewPsk(PskAccepting, nil)
assert.NoError(t, err)
assert.Equal(t, PskAccepting, p.mode)
assert.Nil(t, p.keys[0])
assert.Nil(t, p.primary)

p, err = NewPsk(PskAccepting, []string{"1234567"})
assert.Error(t, ErrKeyTooShort)

p, err = NewPsk(PskAccepting, []string{"hi there friends"})
assert.NoError(t, err)
assert.Equal(t, PskAccepting, p.mode)
assert.Nil(t, p.primary)
assert.Len(t, p.keys, 2)
assert.Nil(t, p.keys[1])

expectedCache := []byte{
0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65,
0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef,
}
assert.Equal(t, expectedCache, p.keys[0])
})

t.Run("mode sending", func(t *testing.T) {
p, err := NewPsk(PskSending, nil)
assert.Error(t, ErrNotEnoughPskKeys, err)

p, err = NewPsk(PskSending, []string{"1234567"})
assert.Error(t, ErrKeyTooShort)

p, err = NewPsk(PskSending, []string{"hi there friends"})
assert.NoError(t, err)
assert.Equal(t, PskSending, p.mode)
assert.Len(t, p.keys, 2)
assert.Nil(t, p.keys[1])

expectedCache := []byte{
0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65,
0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef,
}
assert.Equal(t, expectedCache, p.keys[0])
assert.Equal(t, p.keys[0], p.primary)
})

t.Run("mode enforced", func(t *testing.T) {
p, err := NewPsk(PskEnforced, nil)
assert.Error(t, ErrNotEnoughPskKeys, err)

p, err = NewPsk(PskEnforced, []string{"hi there friends"})
assert.NoError(t, err)
assert.Equal(t, PskEnforced, p.mode)
assert.Len(t, p.keys, 1)

expectedCache := []byte{
0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65,
0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef,
}
assert.Equal(t, expectedCache, p.keys[0])
assert.Equal(t, p.keys[0], p.primary)
})
}

0 comments on commit b07d245

Please sign in to comment.