diff --git a/client.go b/client.go index 5b5460e..52d19ec 100644 --- a/client.go +++ b/client.go @@ -66,6 +66,7 @@ var ( ) type ClientOptions struct { + IceTrickle bool `json:"ice_trickle"` IdleTimeout time.Duration `json:"idle_timeout"` Type string `json:"type"` EnableVoiceDetection bool `json:"enable_voice_detection"` @@ -194,6 +195,7 @@ type Client struct { func DefaultClientOptions() ClientOptions { return ClientOptions{ + IceTrickle: true, IdleTimeout: 5 * time.Minute, Type: ClientTypePeer, EnableVoiceDetection: true, @@ -554,8 +556,8 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi track.OnEnded(func() { simulcastTrack := track.(*SimulcastTrack) - simulcastTrack.mu.Lock() - defer simulcastTrack.mu.Unlock() + // simulcastTrack.mu.Lock() + // defer simulcastTrack.mu.Unlock() if simulcastTrack.remoteTrackHigh != nil { client.stats.removeReceiverStats(simulcastTrack.remoteTrackHigh.track.ID() + simulcastTrack.remoteTrackHigh.track.RID()) } @@ -732,6 +734,11 @@ func (c *Client) Negotiate(offer webrtc.SessionDescription) (*webrtc.SessionDesc return nil, err } + if !c.options.IceTrickle { + gatherComplete := webrtc.GatheringCompletePromise(c.peerConnection.PC()) + <-gatherComplete + } + // allow add candidates once the local description is set c.canAddCandidate.Store(true) diff --git a/client_test.go b/client_test.go index ae0b278..5ae8aca 100644 --- a/client_test.go +++ b/client_test.go @@ -44,7 +44,7 @@ func TestTracksSubscribe(t *testing.T) { clients := make([]*Client, 0) for i := 0; i < peerCount; i++ { - pc, client, _, _ := CreatePeerPair(ctx, TestLogger, testRoom, DefaultTestIceServers(), fmt.Sprintf("peer-%d", i), true, false) + pc, client, _, _ := CreatePeerPair(ctx, TestLogger, testRoom, DefaultTestIceServers(), fmt.Sprintf("peer-%d", i), true, false, true) peers = append(peers, pc) clients = append(clients, client) @@ -202,7 +202,7 @@ Loop: } func addSimulcastPair(t *testing.T, ctx context.Context, room *Room, peerName string, simulcastTrackChan chan *SimulcastTrack) (*Client, *webrtc.PeerConnection) { - pc, client, _, _ := CreatePeerPair(ctx, TestLogger, room, DefaultTestIceServers(), peerName, true, true) + pc, client, _, _ := CreatePeerPair(ctx, TestLogger, room, DefaultTestIceServers(), peerName, true, true, true) client.OnTracksAvailable(func(availableTracks []ITrack) { for _, track := range availableTracks { if track.IsSimulcast() { @@ -256,6 +256,8 @@ func TestClientDataChannel(t *testing.T) { defer cancelTimeout() + defer pc.Close() + select { case <-timeout.Done(): t.Fatal("timeout waiting for data channel") @@ -263,7 +265,7 @@ func TestClientDataChannel(t *testing.T) { if state == webrtc.PeerConnectionStateConnected { _, _ = pc.CreateDataChannel("test", nil) - negotiate(pc, client, TestLogger) + negotiate(pc, client, TestLogger, true) } case dc := <-dcChan: require.Equal(t, "internal", dc.Label()) diff --git a/clienttracksimulcast.go b/clienttracksimulcast.go index 4591d67..962891a 100644 --- a/clienttracksimulcast.go +++ b/clienttracksimulcast.go @@ -281,9 +281,6 @@ func (t *simulcastClientTrack) onEnded() { return } - t.mu.Lock() - defer t.mu.Unlock() - for _, callback := range t.onTrackEndedCallbacks { callback() } diff --git a/go.mod b/go.mod index 4d9bd08..d9d35ce 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/pion/ice/v3 v3.0.16 github.com/pion/ice/v4 v4.0.2 github.com/pion/turn/v3 v3.0.3 - github.com/pion/webrtc/v4 v4.0.0-beta.34 + github.com/pion/webrtc/v4 v4.0.1 golang.org/x/exp v0.0.0-20230905200255-921286631fa9 golang.org/x/text v0.19.0 ) diff --git a/go.sum b/go.sum index afb12cc..d2db5da 100644 --- a/go.sum +++ b/go.sum @@ -57,6 +57,8 @@ github.com/pion/turn/v4 v4.0.0 h1:qxplo3Rxa9Yg1xXDxxH8xaqcyGUtbHYw4QSCvmFWvhM= github.com/pion/turn/v4 v4.0.0/go.mod h1:MuPDkm15nYSklKpN8vWJ9W2M0PlyQZqYt1McGuxG7mA= github.com/pion/webrtc/v4 v4.0.0-beta.34 h1:C5GPomCKm5Xc3iGUsoMGq1oEmv9GYIeadDsel7Qw8B0= github.com/pion/webrtc/v4 v4.0.0-beta.34/go.mod h1:SfNn8CcFxR6OUVjLXVslAQ3a3994JhyE3Hw1jAuqEto= +github.com/pion/webrtc/v4 v4.0.1 h1:6Unwc6JzoTsjxetcAIoWH81RUM4K5dBc1BbJGcF9WVE= +github.com/pion/webrtc/v4 v4.0.1/go.mod h1:SfNn8CcFxR6OUVjLXVslAQ3a3994JhyE3Hw1jAuqEto= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/pkg/fakeclient/fakeclient.go b/pkg/fakeclient/fakeclient.go index cc65895..2a9578f 100644 --- a/pkg/fakeclient/fakeclient.go +++ b/pkg/fakeclient/fakeclient.go @@ -17,7 +17,7 @@ type FakeClient struct { } func Create(ctx context.Context, log logging.LeveledLogger, room *sfu.Room, iceServers []webrtc.ICEServer, id string, simulcast bool) *FakeClient { - pc, client, stats, _ := sfu.CreatePeerPair(ctx, log, room, iceServers, id, true, simulcast) + pc, client, stats, _ := sfu.CreatePeerPair(ctx, log, room, iceServers, id, true, simulcast, true) return &FakeClient{ ID: id, diff --git a/room_test.go b/room_test.go index 292f662..0a3eb1b 100644 --- a/room_test.go +++ b/room_test.go @@ -119,9 +119,9 @@ func TestRoomJoinLeftEvent(t *testing.T) { clients[client.ID()] = client }) - pc1, client1, _, _ := CreatePeerPair(ctx, TestLogger, testRoom, DefaultTestIceServers(), "peer1", false, false) - pc2, client2, _, _ := CreatePeerPair(ctx, TestLogger, testRoom, DefaultTestIceServers(), "peer1", false, false) - pc3, client3, _, _ := CreatePeerPair(ctx, TestLogger, testRoom, DefaultTestIceServers(), "peer1", false, false) + pc1, client1, _, _ := CreatePeerPair(ctx, TestLogger, testRoom, DefaultTestIceServers(), "peer1", false, false, true) + pc2, client2, _, _ := CreatePeerPair(ctx, TestLogger, testRoom, DefaultTestIceServers(), "peer1", false, false, false) + pc3, client3, _, _ := CreatePeerPair(ctx, TestLogger, testRoom, DefaultTestIceServers(), "peer1", false, false, true) defer pc1.PeerConnection.Close() defer pc2.PeerConnection.Close() @@ -204,7 +204,7 @@ func TestRoomStats(t *testing.T) { clients[client.ID()] = client }) - pc1, client1, statsGetter1, done1 := CreatePeerPair(ctx, TestLogger, testRoom, DefaultTestIceServers(), "peer1", false, false) + pc1, client1, statsGetter1, done1 := CreatePeerPair(ctx, TestLogger, testRoom, DefaultTestIceServers(), "peer1", false, false, true) client1.OnTracksAdded(func(addedTracks []ITrack) { setTracks := make(map[string]TrackType, 0) @@ -214,7 +214,7 @@ func TestRoomStats(t *testing.T) { client1.SetTracksSourceType(setTracks) }) - pc2, client2, statsGetter2, done2 := CreatePeerPair(ctx, TestLogger, testRoom, DefaultTestIceServers(), "peer2", false, false) + pc2, client2, statsGetter2, done2 := CreatePeerPair(ctx, TestLogger, testRoom, DefaultTestIceServers(), "peer2", false, false, true) client2.OnTracksAdded(func(addedTracks []ITrack) { setTracks := make(map[string]TrackType, 0) diff --git a/sfu_test.go b/sfu_test.go index 6d2d57d..dceae9f 100644 --- a/sfu_test.go +++ b/sfu_test.go @@ -41,7 +41,7 @@ func TestLeaveRoom(t *testing.T) { for i := 0; i < peerCount; i++ { go func(i int) { - pc, client, _, _ := CreatePeerPair(ctx, TestLogger, testRoom, DefaultTestIceServers(), fmt.Sprintf("peer-%d", i), true, false) + pc, client, _, _ := CreatePeerPair(ctx, TestLogger, testRoom, DefaultTestIceServers(), fmt.Sprintf("peer-%d", i), true, false, true) clients = append(clients, client) @@ -153,7 +153,7 @@ func TestRenegotiation(t *testing.T) { pairs := make([]Pair, 0) for i := 0; i < peerCount; i++ { - pc, client, _, _ := CreatePeerPair(ctx, TestLogger, testRoom, DefaultTestIceServers(), fmt.Sprintf("peer-%d", i), true, false) + pc, client, _, _ := CreatePeerPair(ctx, TestLogger, testRoom, DefaultTestIceServers(), fmt.Sprintf("peer-%d", i), true, false, true) pairs = append(pairs, Pair{pc.PeerConnection, client}) @@ -195,7 +195,7 @@ Loop: newTrack, _ := GetStaticVideoTrack(timeout, iceConnectedCtx, GenerateSecureToken(), GenerateSecureToken(), true, "") _, err := pair.pc.AddTransceiverFromTrack(newTrack) require.NoError(t, err, "error adding track: %v", err) - negotiate(pair.pc, pair.client, TestLogger) + negotiate(pair.pc, pair.client, TestLogger, true) } }() } diff --git a/testhelper.go b/testhelper.go index d3ae89e..6a48b56 100644 --- a/testhelper.go +++ b/testhelper.go @@ -369,7 +369,7 @@ func GetMediaEngine() *webrtc.MediaEngine { return mediaEngine } -func CreatePeerPair(ctx context.Context, log logging.LeveledLogger, room *Room, iceServers []webrtc.ICEServer, peerName string, loop, isSimulcast bool) (*PC, *Client, stats.Getter, chan bool) { +func CreatePeerPair(ctx context.Context, log logging.LeveledLogger, room *Room, iceServers []webrtc.ICEServer, peerName string, loop, isSimulcast bool, isIceTrickle bool) (*PC, *Client, stats.Getter, chan bool) { clientContext, cancelClient := context.WithCancel(ctx) var ( client *Client @@ -431,6 +431,10 @@ func CreatePeerPair(ctx context.Context, log logging.LeveledLogger, room *Room, _ = room.StopClient(client.ID()) } cancelClient() + + if state == webrtc.PeerConnectionStateFailed { + pc.Close() + } } }) @@ -510,7 +514,10 @@ func CreatePeerPair(ctx context.Context, log logging.LeveledLogger, room *Room, // add a new client to room // you can also get the client by using r.GetClient(clientID) id := room.CreateClientID() - client, _ = room.AddClient(id, id, DefaultClientOptions()) + opts := DefaultClientOptions() + opts.IceTrickle = isIceTrickle + + client, _ = room.AddClient(id, id, opts) client.OnTracksAdded(func(addedTracks []ITrack) { setTracks := make(map[string]TrackType, 0) @@ -565,30 +572,32 @@ func CreatePeerPair(ctx context.Context, log logging.LeveledLogger, room *Room, client.OnAllowedRemoteRenegotiation(func() { log.Infof("allowed remote renegotiation") - negotiate(pc, client, log) + negotiate(pc, client, log, isIceTrickle) }) - client.OnIceCandidate(func(ctx context.Context, candidate *webrtc.ICECandidate) { - if candidate == nil { - return - } + if isIceTrickle { + client.OnIceCandidate(func(ctx context.Context, candidate *webrtc.ICECandidate) { + if candidate == nil { + return + } - _ = pc.AddICECandidate(candidate.ToJSON()) - }) + _ = pc.AddICECandidate(candidate.ToJSON()) + }) - negotiate(pc, client, log) + pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { + if candidate == nil { + return + } + err = client.PeerConnection().PC().AddICECandidate(candidate.ToJSON()) + }) + } - pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { - if candidate == nil { - return - } - err = client.PeerConnection().PC().AddICECandidate(candidate.ToJSON()) - }) + negotiate(pc, client, log, isIceTrickle) return peer, client, statsGetter, allDone } -func negotiate(pc *webrtc.PeerConnection, client *Client, log logging.LeveledLogger) { +func negotiate(pc *webrtc.PeerConnection, client *Client, log logging.LeveledLogger, iceTrickle bool) { if pc.SignalingState() != webrtc.SignalingStateStable { log.Infof("test: signaling state is not stable, skip renegotiation") return @@ -603,7 +612,12 @@ func negotiate(pc *webrtc.PeerConnection, client *Client, log logging.LeveledLog _ = pc.SetLocalDescription(offer) - answer, _ := client.Negotiate(offer) + if !iceTrickle { + gatheringComplete := webrtc.GatheringCompletePromise(pc) + <-gatheringComplete + } + + answer, _ := client.Negotiate(*pc.LocalDescription()) if answer != nil { _ = pc.SetRemoteDescription(*answer) } @@ -664,9 +678,19 @@ func CreateDataPair(ctx context.Context, log logging.LeveledLogger, room *Room, if client != nil { _ = room.StopClient(client.ID()) } + + if state == webrtc.PeerConnectionStateFailed { + pc.Close() + } } - connChan <- state + ctxx, cancel := context.WithCancel(ctx) + defer cancel() + select { + case connChan <- state: + case <-ctxx.Done(): + + } }) @@ -677,7 +701,7 @@ func CreateDataPair(ctx context.Context, log logging.LeveledLogger, room *Room, client.OnAllowedRemoteRenegotiation(func() { log.Infof("allowed remote renegotiation") - go negotiate(pc, client, log) + go negotiate(pc, client, log, true) }) client.OnIceCandidate(func(ctx context.Context, candidate *webrtc.ICECandidate) { @@ -702,7 +726,7 @@ func CreateDataPair(ctx context.Context, log logging.LeveledLogger, room *Room, return *pc.LocalDescription(), nil }) - negotiate(pc, client, log) + negotiate(pc, client, log, true) pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { if candidate == nil { diff --git a/track.go b/track.go index 17b450c..79fac86 100644 --- a/track.go +++ b/track.go @@ -611,9 +611,7 @@ func (t *SimulcastTrack) AddRemoteTrack(track IRemoteTrack, minWait, maxWait tim t.mu.Unlock() remoteTrack.OnEnded(func() { - t.mu.Lock() t.remoteTrackHigh = nil - t.mu.Unlock() t.cancel() t.onEnded() }) @@ -624,9 +622,7 @@ func (t *SimulcastTrack) AddRemoteTrack(track IRemoteTrack, minWait, maxWait tim t.mu.Unlock() remoteTrack.OnEnded(func() { - t.mu.Lock() t.remoteTrackMid = nil - t.mu.Unlock() t.cancel() t.onEnded() }) @@ -637,9 +633,7 @@ func (t *SimulcastTrack) AddRemoteTrack(track IRemoteTrack, minWait, maxWait tim t.mu.Unlock() remoteTrack.OnEnded(func() { - t.mu.Lock() t.remoteTrackLow = nil - t.mu.Unlock() t.cancel() t.onEnded() }) @@ -931,9 +925,6 @@ func (t *SimulcastTrack) OnEnded(f func()) { } func (t *SimulcastTrack) onEnded() { - t.mu.Lock() - defer t.mu.Unlock() - for _, f := range t.onEndedCallbacks { f() } diff --git a/track_test.go b/track_test.go index 2ecf8db..fd01bed 100644 --- a/track_test.go +++ b/track_test.go @@ -71,7 +71,7 @@ func createPeerAudio(ctx context.Context, room *Room, iceServers []webrtc.ICESer client.OnAllowedRemoteRenegotiation(func() { TestLogger.Info("allowed remote renegotiation") - negotiate(pc, client, TestLogger) + negotiate(pc, client, TestLogger, true) }) client.OnIceCandidate(func(ctx context.Context, candidate *webrtc.ICECandidate) { @@ -108,7 +108,7 @@ func createPeerAudio(ctx context.Context, room *Room, iceServers []webrtc.ICESer return *pc.LocalDescription(), nil }) - negotiate(pc, client, TestLogger) + negotiate(pc, client, TestLogger, true) pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { if candidate == nil {