Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[client] Refactor the worker callbacks to channel #2827

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 44 additions & 24 deletions client/internal/peer/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,36 +134,29 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
statusICE: NewAtomicConnStatus(),
}

rFns := WorkerRelayCallbacks{
OnConnReady: conn.relayConnectionIsReady,
OnDisconnected: conn.onWorkerRelayStateDisconnected,
}

wFns := WorkerICECallbacks{
OnConnReady: conn.iCEConnectionIsReady,
OnStatusChanged: conn.onWorkerICEStateDisconnected,
}

ctrl := isController(config)
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns)
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager)

relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns)
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
if err != nil {
return nil, err
}

conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay)

conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
conn.workerRelay.OnNewOffer(ctx, remoteOfferAnswer)
})
if os.Getenv("NB_FORCE_RELAY") != "true" {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
conn.workerICE.OnNewOffer(ctx, remoteOfferAnswer)
})
}

conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher)

go conn.handshaker.Listen()

return conn, nil
}

Expand All @@ -190,6 +183,7 @@ func (conn *Conn) Open() {
}

go conn.startHandshakeAndReconnect(conn.ctx)
go conn.listenWorkersEvents()
}

func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
Expand Down Expand Up @@ -301,7 +295,7 @@ func (conn *Conn) GetKey() string {
}

// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
func (conn *Conn) iCEConnectionIsReady(iceConnInfo ICEConnInfo) {
conn.mu.Lock()
defer conn.mu.Unlock()

Expand All @@ -311,7 +305,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon

conn.log.Debugf("ICE connection is ready")

if conn.currentConnPriority > priority {
if conn.currentConnPriority > iceConnInfo.ConnPriority {
conn.statusICE.Set(StatusConnected)
conn.updateIceState(iceConnInfo)
return
Expand All @@ -333,7 +327,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
ep = wgProxy.EndpointAddr()
conn.wgProxyICE = wgProxy
} else {
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String())
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteIceCandidateEndpoint)
if err != nil {
log.Errorf("failed to resolveUDPaddr")
conn.handleConfigurationFailure(err, nil)
Expand Down Expand Up @@ -361,22 +355,21 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
return
}
wgConfigWorkaround()
conn.currentConnPriority = priority
conn.currentConnPriority = iceConnInfo.ConnPriority
conn.statusICE.Set(StatusConnected)
conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
}

// todo review to make sense to handle connecting and disconnected status also?
func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
func (conn *Conn) onWorkerICEStateDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()

if conn.ctx.Err() != nil {
return
}

conn.log.Tracef("ICE connection state changed to %s", newState)
conn.log.Tracef("ICE connection state changed to disconnected")

if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil {
Expand All @@ -396,8 +389,8 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.currentConnPriority = connPriorityRelay
}

changed := conn.statusICE.Get() != newState && newState != StatusConnecting
conn.statusICE.Set(newState)
changed := conn.statusICE.Get() != stateDisconnected
conn.statusICE.Set(stateDisconnected)

conn.guard.SetICEConnDisconnected(changed)

Expand Down Expand Up @@ -731,6 +724,33 @@ func (conn *Conn) logTraceConnState() {
}
}

func (conn *Conn) listenWorkersEvents() {
for {
select {
case e := <-conn.workerRelay.EventChan:
switch e.ConnStatus {
case StatusConnected:
conn.relayConnectionIsReady(e.RelayConnInfo)
case StatusDisconnected:
conn.onWorkerRelayStateDisconnected()
default:
log.Errorf("unexpected relay connection status: %v", e.ConnStatus)
}
case e := <-conn.workerICE.EventChan:
switch e.ConnStatus {
case StatusConnected:
conn.iCEConnectionIsReady(e.ICEConnInfo)
case StatusDisconnected:
conn.onWorkerICEStateDisconnected()
default:
log.Errorf("unexpected ICE connection status: %v", e.ConnStatus)
}
case <-conn.ctx.Done():
return
}
}
}

func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}
Expand Down
36 changes: 25 additions & 11 deletions client/internal/peer/worker_ice.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,14 @@ import (
"github.com/netbirdio/netbird/route"
)

type ICEEvent struct {
ConnStatus ConnStatus
ICEConnInfo ICEConnInfo
}

type ICEConnInfo struct {
RemoteConn net.Conn
RemoteAddr net.Addr
RosenpassPubKey []byte
RosenpassAddr string
LocalIceCandidateType string
Expand All @@ -29,22 +35,18 @@ type ICEConnInfo struct {
LocalIceCandidateEndpoint string
Relayed bool
RelayedOnLocal bool
}

type WorkerICECallbacks struct {
OnConnReady func(ConnPriority, ICEConnInfo)
OnStatusChanged func(ConnStatus)
ConnPriority ConnPriority
}

type WorkerICE struct {
EventChan chan ICEEvent
ctx context.Context
log *log.Entry
config ConnConfig
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
statusRecorder *Status
hasRelayOnLocally bool
conn WorkerICECallbacks

selectedPriority ConnPriority

Expand All @@ -59,16 +61,16 @@ type WorkerICE struct {
localPwd string
}

func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
w := &WorkerICE{
EventChan: make(chan ICEEvent, 2),
ctx: ctx,
log: log,
config: config,
signaler: signaler,
iFaceDiscover: ifaceDiscover,
statusRecorder: statusRecorder,
hasRelayOnLocally: hasRelayOnLocally,
conn: callBacks,
}

localUfrag, localPwd, err := icemaker.GenerateICECredentials()
Expand All @@ -80,7 +82,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signal
return w, nil
}

func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
func (w *WorkerICE) OnNewOffer(_ context.Context, remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("OnNewOffer for ICE")
w.muxAgent.Lock()

Expand Down Expand Up @@ -133,6 +135,11 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
return
}

if pair == nil {
w.log.Errorf("remote address is nil, ICE conn already closed")
return
}

if !isRelayCandidate(pair.Local) {
// dynamically set remote WireGuard port if other side specified a different one from the default one
remoteWgPort := iface.DefaultWgPort
Expand All @@ -154,9 +161,13 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
Relayed: isRelayed(pair),
RelayedOnLocal: isRelayCandidate(pair.Local),
ConnPriority: w.selectedPriority,
}
w.log.Debugf("on ICE conn read to use ready")
go w.conn.OnConnReady(w.selectedPriority, ci)
select {
case w.EventChan <- ICEEvent{ConnStatus: StatusConnected, ICEConnInfo: ci}:
case <-w.ctx.Done():
}
}

// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
Expand Down Expand Up @@ -216,7 +227,10 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
w.conn.OnStatusChanged(StatusDisconnected)
select {
case w.EventChan <- ICEEvent{ConnStatus: StatusDisconnected}:
case <-w.ctx.Done():
}

w.muxAgent.Lock()
agentCancel()
Expand Down
52 changes: 34 additions & 18 deletions client/internal/peer/worker_relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,23 @@ var (
wgHandshakeOvertime = 30 * time.Second
)

type RelayEvent struct {
ConnStatus ConnStatus
RelayConnInfo RelayConnInfo
}

type RelayConnInfo struct {
relayedConn net.Conn
rosenpassPubKey []byte
rosenpassAddr string
}

type WorkerRelayCallbacks struct {
OnConnReady func(RelayConnInfo)
OnDisconnected func()
}

type WorkerRelay struct {
EventChan chan RelayEvent
log *log.Entry
isController bool
config ConnConfig
relayManager relayClient.ManagerService
callBacks WorkerRelayCallbacks

relayedConn net.Conn
relayLock sync.Mutex
Expand All @@ -45,18 +45,18 @@ type WorkerRelay struct {
relaySupportedOnRemotePeer atomic.Bool
}

func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay {
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService) *WorkerRelay {
r := &WorkerRelay{
EventChan: make(chan RelayEvent, 2),
log: log,
isController: ctrl,
config: config,
relayManager: relayManager,
callBacks: callbacks,
}
return r
}

func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
func (w *WorkerRelay) OnNewOffer(ctx context.Context, remoteOfferAnswer *OfferAnswer) {
if !w.isRelaySupported(remoteOfferAnswer) {
w.log.Infof("Relay is not supported by remote peer")
w.relaySupportedOnRemotePeer.Store(false)
Expand Down Expand Up @@ -87,19 +87,27 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.relayedConn = relayedConn
w.relayLock.Unlock()

err = w.relayManager.AddCloseListener(srv, w.onRelayMGDisconnected)
err = w.relayManager.AddCloseListener(srv, func() {
w.onRelayMGDisconnected(ctx)
})
if err != nil {
log.Errorf("failed to add close listener: %s", err)
_ = relayedConn.Close()
return
}

w.log.Debugf("peer conn opened via Relay: %s", srv)
go w.callBacks.OnConnReady(RelayConnInfo{
relayedConn: relayedConn,
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
})
select {
case w.EventChan <- RelayEvent{
ConnStatus: StatusConnected,
RelayConnInfo: RelayConnInfo{
relayedConn: relayedConn,
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
},
}:
case <-ctx.Done():
}
}

func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
Expand Down Expand Up @@ -187,7 +195,11 @@ func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.Cancel
w.relayLock.Lock()
_ = w.relayedConn.Close()
w.relayLock.Unlock()
w.callBacks.OnDisconnected()

select {
case w.EventChan <- RelayEvent{ConnStatus: StatusDisconnected}:
case <-ctx.Done():
}
return
}

Expand Down Expand Up @@ -225,12 +237,16 @@ func (w *WorkerRelay) wgState() (time.Time, error) {
return wgState.LastHandshake, nil
}

func (w *WorkerRelay) onRelayMGDisconnected() {
func (w *WorkerRelay) onRelayMGDisconnected(ctx context.Context) {
w.ctxLock.Lock()
defer w.ctxLock.Unlock()

if w.ctxCancelWgWatch != nil {
w.ctxCancelWgWatch()
}
go w.callBacks.OnDisconnected()

select {
case w.EventChan <- RelayEvent{ConnStatus: StatusDisconnected}:
case <-ctx.Done():
}
}
Loading