From f1a5f10ee527fc94e62039eab1128526299d5d4d Mon Sep 17 00:00:00 2001 From: Yohan Totting Date: Mon, 20 Nov 2023 15:52:17 +0700 Subject: [PATCH 01/13] refactor and add IRemoteTrack for relay track --- client.go | 85 ++++++++++++++++++-------------------------------- client_test.go | 4 +-- remotetrack.go | 79 ++++++++++++++++++++++++---------------------- sfu.go | 52 ++++++++++++++++++++++++++++-- track.go | 75 ++++++++++++++++++++++++++++---------------- 5 files changed, 172 insertions(+), 123 deletions(-) diff --git a/client.go b/client.go index 429523b..987162d 100644 --- a/client.go +++ b/client.go @@ -329,9 +329,33 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi vad = vadInterceptor.AddAudioTrack(remoteTrack) } + onPLI := func() error { + if client.peerConnection == nil || client.peerConnection.PC() == nil || client.peerConnection.PC().ConnectionState() != webrtc.PeerConnectionStateConnected { + return nil + } + + return client.peerConnection.PC().WriteRTCP([]rtcp.Packet{ + &rtcp.PictureLossIndication{MediaSSRC: uint32(remoteTrack.SSRC())}, + }) + } + + onStatsUpdated := func(stats *stats.Stats) { + client.mu.Lock() + defer client.mu.Unlock() + + client.stats.SetReceiver(track.ID(), *stats) + glog.Info("client: stats updated ", track.ID(), " ", stats) + } + if remoteTrack.RID() == "" { // not simulcast - track = newTrack(client, remoteTrack, receiver, vad) + + track = newTrack(client.context, client.id, remoteTrack, receiver, s.pliInterval, onPLI, vad, client.statsGetter, onStatsUpdated) + + track.OnEnded(func() { + client.stats.removeReceiverStats(remoteTrack.ID()) + }) + if err := client.tracks.Add(track); err != nil { glog.Error("client: error add track ", err) } @@ -349,12 +373,16 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi if err != nil { // if track not found, add it - track = newSimulcastTrack(client, remoteTrack, receiver) + track = newSimulcastTrack(client.context, client.id, remoteTrack, receiver, s.pliInterval, onPLI, client.statsGetter, onStatsUpdated) if err := client.tracks.Add(track); err != nil { glog.Error("client: error add track ", err) } + + track.OnEnded(func() { + client.stats.removeReceiverStats(remoteTrack.ID()) + }) } else if simulcast, ok = track.(*simulcastTrack); ok { - simulcast.AddRemoteTrack(remoteTrack, receiver) + simulcast.AddRemoteTrack(remoteTrack, receiver, client.statsGetter, onStatsUpdated) } // // only process track when the highest quality is available @@ -670,31 +698,6 @@ func (c *Client) setClientTrack(t ITrack) iClientTrack { return nil } - t.Client().OnLeft(func() { - if c == nil { - return - } - - c.mu.Lock() - defer c.mu.Unlock() - - sender := transc.Sender() - if sender == nil { - return - } - - if c.peerConnection == nil || c.peerConnection.PC() == nil || sender == nil { - return - } - - if err := c.peerConnection.PC().RemoveTrack(sender); err != nil { - glog.Error("client: error remove track ", err) - return - } - - c.renegotiate() - }) - t.OnEnded(func() { if c == nil { return @@ -940,32 +943,6 @@ func (c *Client) Stats() *ClientStats { return c.stats } -func (c *Client) updateReceiverStats(remoteTrack *remoteTrack) { - c.mu.Lock() - defer c.mu.Unlock() - - if c.statsGetter == nil { - return - } - - if remoteTrack.track == nil { - return - } - - track := remoteTrack.track - - if track.SSRC() == 0 { - return - } - - stats := c.statsGetter.Get(uint32(track.SSRC())) - if stats != nil { - remoteTrack.setReceiverStats(*stats) - c.stats.SetReceiver(track.ID(), *stats) - } - -} - func (c *Client) updateSenderStats(sender *webrtc.RTPSender) { c.mu.Lock() defer c.mu.Unlock() diff --git a/client_test.go b/client_test.go index df7a05b..5da054f 100644 --- a/client_test.go +++ b/client_test.go @@ -37,7 +37,7 @@ func TestTracksManualSubscribe(t *testing.T) { tracksReq := make([]SubscribeTrackRequest, 0) for _, track := range availableTracks { tracksReq = append(tracksReq, SubscribeTrackRequest{ - ClientID: track.Client().ID(), + ClientID: track.ClientID(), TrackID: track.ID(), }) } @@ -209,7 +209,7 @@ func addSimulcastPair(t *testing.T, ctx context.Context, room *Room, peerName st tracksReq := make([]SubscribeTrackRequest, 0) for _, track := range availableTracks { tracksReq = append(tracksReq, SubscribeTrackRequest{ - ClientID: track.Client().ID(), + ClientID: track.ClientID(), TrackID: track.ID(), }) } diff --git a/remotetrack.go b/remotetrack.go index 92fc387..8462aae 100644 --- a/remotetrack.go +++ b/remotetrack.go @@ -10,53 +10,72 @@ import ( "github.com/golang/glog" "github.com/pion/interceptor/pkg/stats" - "github.com/pion/rtcp" "github.com/pion/rtp" "github.com/pion/webrtc/v3" ) type remoteTrack struct { - client *Client context context.Context cancel context.CancelFunc mu sync.Mutex - track *webrtc.TrackRemote + track IRemoteTrack receiver *webrtc.RTPReceiver - onRead func(*rtp.Packet) + onReadCallbacks []func(*rtp.Packet) + onPLI func() error bitrate *atomic.Uint32 previousBytesReceived *atomic.Uint64 currentBytesReceived *atomic.Uint64 latestUpdatedTS *atomic.Uint64 lastPLIRequestTime time.Time onEndedCallbacks []func() - stats stats.Stats + statsGetter stats.Getter + onStatsUpdated func(*stats.Stats) } -func newRemoteTrack(client *Client, track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver, onRead func(*rtp.Packet)) *remoteTrack { - ctx, cancel := context.WithCancel(client.context) +func newRemoteTrack(ctx context.Context, track IRemoteTrack, receiver *webrtc.RTPReceiver, pliInterval time.Duration, onPLI func() error, statsGetter stats.Getter, onStatsUpdated func(*stats.Stats)) *remoteTrack { + localctx, cancel := context.WithCancel(ctx) rt := &remoteTrack{ - context: ctx, + context: localctx, cancel: cancel, - client: client, mu: sync.Mutex{}, track: track, receiver: receiver, - onRead: onRead, bitrate: &atomic.Uint32{}, previousBytesReceived: &atomic.Uint64{}, currentBytesReceived: &atomic.Uint64{}, latestUpdatedTS: &atomic.Uint64{}, onEndedCallbacks: make([]func(), 0), - stats: stats.Stats{}, + statsGetter: statsGetter, + onStatsUpdated: onStatsUpdated, + onPLI: onPLI, + onReadCallbacks: make([]func(*rtp.Packet), 0), } - rt.enableIntervalPLI(client.sfu.PLIInterval()) + rt.enableIntervalPLI(pliInterval) rt.readRTP() return rt } +func (t *remoteTrack) OnRead(f func(*rtp.Packet)) { + t.mu.Lock() + defer t.mu.Unlock() + + t.onReadCallbacks = append(t.onReadCallbacks, f) +} + +func (t *remoteTrack) onRead(p *rtp.Packet) { + t.mu.Lock() + defer t.mu.Unlock() + + for _, f := range t.onReadCallbacks { + f(p) + } + + go t.updateStats() +} + func (t *remoteTrack) OnEnded(f func()) { t.mu.Lock() defer t.mu.Unlock() @@ -84,8 +103,6 @@ func (t *remoteTrack) readRTP() { if readErr == io.EOF { t.onEnded() - t.client.stats.removeReceiverStats(t.track.ID()) - return } else if readErr != nil { glog.Error("error reading rtp: ", readErr.Error()) @@ -93,16 +110,18 @@ func (t *remoteTrack) readRTP() { } t.onRead(rtp) - - go t.client.updateReceiverStats(t) - } } }() } func (t *remoteTrack) updateStats() { - s := t.stats + s := t.statsGetter.Get(uint32(t.track.SSRC())) + if s == nil { + glog.Warning("remotetrack: stats not found for track: ", t.track.SSRC()) + return + } + // update the stats if the last update equal or more than 1 second latestUpdated := t.latestUpdatedTS.Load() if time.Since(time.Unix(0, int64(latestUpdated))).Seconds() <= 1 { @@ -123,9 +142,12 @@ func (t *remoteTrack) updateStats() { t.bitrate.Store(uint32((s.BytesReceived-current)*8) / uint32(deltaTime.Seconds())) + if t.onStatsUpdated != nil { + t.onStatsUpdated(s) + } } -func (t *remoteTrack) Track() *webrtc.TrackRemote { +func (t *remoteTrack) Track() IRemoteTrack { return t.track } @@ -133,21 +155,6 @@ func (t *remoteTrack) GetCurrentBitrate() uint32 { return t.bitrate.Load() } -func (t *remoteTrack) receiverStats() stats.Stats { - t.mu.Lock() - defer t.mu.Unlock() - - return t.stats -} - -func (t *remoteTrack) setReceiverStats(s stats.Stats) { - t.mu.Lock() - defer t.mu.Unlock() - t.stats = s - - t.updateStats() -} - func (t *remoteTrack) sendPLI() error { t.mu.Lock() defer t.mu.Unlock() @@ -161,9 +168,7 @@ func (t *remoteTrack) sendPLI() error { t.lastPLIRequestTime = time.Now() - return t.client.peerConnection.PC().WriteRTCP([]rtcp.Packet{ - &rtcp.PictureLossIndication{MediaSSRC: uint32(t.track.SSRC())}, - }) + return t.onPLI() } func (t *remoteTrack) enableIntervalPLI(interval time.Duration) { diff --git a/sfu.go b/sfu.go index ac78a97..b955085 100644 --- a/sfu.go +++ b/sfu.go @@ -220,7 +220,7 @@ func (s *SFU) NewClient(id, name string, opts ClientOptions) *Client { for _, c := range s.clients.GetClients() { for _, track := range c.tracks.GetTracks() { - if track.Client().ID() != client.ID() { + if track.ClientID() != client.ID() { availableTracks = append(availableTracks, track) } } @@ -397,7 +397,7 @@ func (s *SFU) onTracksAvailable(tracks []ITrack) { // filter out tracks from the same client filteredTracks := make([]ITrack, 0) for _, track := range tracks { - if track.Client().ID() != client.ID() { + if track.ClientID() != client.ID() { filteredTracks = append(filteredTracks, track) } } @@ -419,7 +419,7 @@ func (s *SFU) broadcastTracksToAutoSubscribeClients(ownerID string, tracks []ITr trackReq := make([]SubscribeTrackRequest, 0) for _, track := range tracks { trackReq = append(trackReq, SubscribeTrackRequest{ - ClientID: track.Client().ID(), + ClientID: track.ClientID(), TrackID: track.ID(), }) } @@ -582,3 +582,49 @@ func (s *SFU) OnTracksAvailable(callback func(tracks []ITrack)) { s.onTrackAvailableCallbacks = append(s.onTrackAvailableCallbacks, callback) } + +// func (s *SFU) AddRelayTrack(id, streamid, rid string, kind webrtc.RTPCodecType, ssrc webrtc.SSRC, mimeType string, rtpChan chan *rtp.Packet) error { +// var track ITrack + +// relayTrack := NewTrackRelay(id, streamid, rid, kind, ssrc, mimeType, rtpChan) + +// if rid == "" { +// // not simulcast +// track = newTrack(client, remoteTrack, receiver, vad) +// if err := client.tracks.Add(track); err != nil { +// glog.Error("client: error add track ", err) +// } + +// client.onTrack(track) +// track.SetAsProcessed() +// } else { +// // simulcast +// var simulcast *simulcastTrack +// var ok bool + +// id := remoteTrack.ID() + +// track, err = client.tracks.Get(id) // not found because the track is not added yet due to race condition + +// if err != nil { +// // if track not found, add it +// track = newSimulcastTrack(client, remoteTrack, receiver) +// if err := client.tracks.Add(track); err != nil { +// glog.Error("client: error add track ", err) +// } +// } else if simulcast, ok = track.(*simulcastTrack); ok { +// simulcast.AddRemoteTrack(remoteTrack, receiver) +// } + +// // // only process track when the highest quality is available +// // simulcast.mu.Lock() +// // isHighAvailable := simulcast.remoteTrackHigh != nil +// // simulcast.mu.Unlock() + +// if !track.IsProcessed() { +// client.onTrack(track) +// track.SetAsProcessed() +// } + +// } +// } diff --git a/track.go b/track.go index bfa3474..efa90ae 100644 --- a/track.go +++ b/track.go @@ -1,6 +1,7 @@ package sfu import ( + "context" "errors" "sync" "sync/atomic" @@ -8,6 +9,7 @@ import ( "github.com/golang/glog" "github.com/inlivedev/sfu/pkg/interceptors/voiceactivedetector" + "github.com/pion/interceptor/pkg/stats" "github.com/pion/rtp" "github.com/pion/webrtc/v3" ) @@ -32,7 +34,7 @@ type baseTrack struct { id string msid string streamid string - client *Client + clientid string isProcessed bool kind webrtc.RTPCodecType codec webrtc.RTPCodecParameters @@ -43,7 +45,7 @@ type baseTrack struct { type ITrack interface { ID() string StreamID() string - Client() *Client + ClientID() string IsSimulcast() bool IsScaleable() bool IsProcessed() bool @@ -67,7 +69,7 @@ type track struct { onReadCallbacks []func(*rtp.Packet, QualityLevel) } -func newTrack(client *Client, trackRemote *webrtc.TrackRemote, receiver *webrtc.RTPReceiver, vad *voiceactivedetector.VoiceDetector) ITrack { +func newTrack(ctx context.Context, clientID string, trackRemote *webrtc.TrackRemote, receiver *webrtc.RTPReceiver, pliInterval time.Duration, onPLI func() error, vad *voiceactivedetector.VoiceDetector, stats stats.Getter, onStatsUpdated func(*stats.Stats)) ITrack { ctList := newClientTrackList() baseTrack := baseTrack{ @@ -75,7 +77,7 @@ func newTrack(client *Client, trackRemote *webrtc.TrackRemote, receiver *webrtc. isScreen: &atomic.Bool{}, msid: trackRemote.Msid(), streamid: trackRemote.StreamID(), - client: client, + clientid: clientID, kind: trackRemote.Kind(), codec: trackRemote.Codec(), clientTracks: ctList, @@ -89,7 +91,8 @@ func newTrack(client *Client, trackRemote *webrtc.TrackRemote, receiver *webrtc. onEndedCallbacks: make([]func(), 0), } - onTrackRead := func(p *rtp.Packet) { + t.remoteTrack = newRemoteTrack(ctx, trackRemote, receiver, pliInterval, onPLI, stats, onStatsUpdated) + t.remoteTrack.OnRead(func(p *rtp.Packet) { // do tracks := ctList.GetTracks() for _, track := range tracks { @@ -97,9 +100,7 @@ func newTrack(client *Client, trackRemote *webrtc.TrackRemote, receiver *webrtc. } t.onRead(p, QualityHigh) - } - - t.remoteTrack = newRemoteTrack(client, trackRemote, receiver, onTrackRead) + }) t.remoteTrack.OnEnded(func() { t.onEnded() @@ -108,6 +109,10 @@ func newTrack(client *Client, trackRemote *webrtc.TrackRemote, receiver *webrtc. return t } +func (t *track) ClientID() string { + return t.base.clientid +} + func (t *track) createLocalTrack() *webrtc.TrackLocalStaticRTP { track, newTrackErr := webrtc.NewTrackLocalStaticRTP(t.remoteTrack.track.Codec().RTPCodecCapability, t.base.id, t.base.streamid) if newTrackErr != nil { @@ -125,10 +130,6 @@ func (t *track) StreamID() string { return t.base.streamid } -func (t *track) Client() *Client { - return t.base.client -} - func (t *track) RemoteTrack() *remoteTrack { t.mu.Lock() defer t.mu.Unlock() @@ -214,6 +215,13 @@ func (t *track) subscribe(c *Client) iClientTrack { ct.onTrackEnded() }) + go func() { + clientCtx, cancel := context.WithCancel(c.context) + defer cancel() + <-clientCtx.Done() + ct.onTrackEnded() + }() + t.base.clientTracks.Add(ct) return ct @@ -263,6 +271,7 @@ func (t *track) onRead(p *rtp.Packet, quality QualityLevel) { } type simulcastTrack struct { + context context.Context mu sync.Mutex base *baseTrack baseTS uint32 @@ -288,17 +297,20 @@ type simulcastTrack struct { onAddedRemoteTrackCallbacks []func(*remoteTrack) onEndedCallbacks []func() onReadCallbacks []func(*rtp.Packet, QualityLevel) + pliInterval time.Duration + onPLI func() error } -func newSimulcastTrack(client *Client, track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) ITrack { +func newSimulcastTrack(ctx context.Context, clientid string, track IRemoteTrack, receiver *webrtc.RTPReceiver, pliInterval time.Duration, onPLI func() error, stats stats.Getter, onStatsUpdated func(*stats.Stats)) ITrack { t := &simulcastTrack{ - mu: sync.Mutex{}, + context: ctx, + mu: sync.Mutex{}, base: &baseTrack{ id: track.ID(), isScreen: &atomic.Bool{}, msid: track.Msid(), streamid: track.StreamID(), - client: client, + clientid: clientid, kind: track.Kind(), codec: track.Codec(), clientTracks: newClientTrackList(), @@ -312,13 +324,19 @@ func newSimulcastTrack(client *Client, track *webrtc.TrackRemote, receiver *webr onAddedRemoteTrackCallbacks: make([]func(*remoteTrack), 0), onEndedCallbacks: make([]func(), 0), onReadCallbacks: make([]func(*rtp.Packet, QualityLevel), 0), + pliInterval: pliInterval, + onPLI: onPLI, } - _ = t.AddRemoteTrack(track, receiver) + _ = t.AddRemoteTrack(track, receiver, stats, onStatsUpdated) return t } +func (t *simulcastTrack) ClientID() string { + return t.base.clientid +} + func (t *simulcastTrack) onRemoteTrackAdded(f func(*remoteTrack)) { t.mu.Lock() defer t.mu.Unlock() @@ -351,10 +369,6 @@ func (t *simulcastTrack) StreamID() string { return t.base.streamid } -func (t *simulcastTrack) Client() *Client { - return t.base.client -} - func (t *simulcastTrack) IsSimulcast() bool { return true } @@ -374,7 +388,7 @@ func (t *simulcastTrack) Kind() webrtc.RTPCodecType { return t.base.kind } -func (t *simulcastTrack) AddRemoteTrack(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) *remoteTrack { +func (t *simulcastTrack) AddRemoteTrack(track IRemoteTrack, receiver *webrtc.RTPReceiver, stats stats.Getter, onStatsUpdated func(*stats.Stats)) *remoteTrack { var remoteTrack *remoteTrack quality := RIDToQuality(track.RID()) @@ -420,9 +434,11 @@ func (t *simulcastTrack) AddRemoteTrack(track *webrtc.TrackRemote, receiver *web t.mu.Lock() + remoteTrack = newRemoteTrack(t.context, track, receiver, t.pliInterval, t.onPLI, stats, onStatsUpdated) + remoteTrack.OnRead(onRead) switch quality { case QualityHigh: - remoteTrack = newRemoteTrack(t.base.client, track, receiver, onRead) + t.remoteTrackHigh = remoteTrack remoteTrack.OnEnded(func() { t.mu.Lock() @@ -435,7 +451,6 @@ func (t *simulcastTrack) AddRemoteTrack(track *webrtc.TrackRemote, receiver *web }) case QualityMid: - remoteTrack = newRemoteTrack(t.base.client, track, receiver, onRead) t.remoteTrackMid = remoteTrack remoteTrack.OnEnded(func() { t.mu.Lock() @@ -447,7 +462,6 @@ func (t *simulcastTrack) AddRemoteTrack(track *webrtc.TrackRemote, receiver *web } }) case QualityLow: - remoteTrack = newRemoteTrack(t.base.client, track, receiver, onRead) t.remoteTrackLow = remoteTrack remoteTrack.OnEnded(func() { t.mu.Lock() @@ -535,6 +549,13 @@ func (t *simulcastTrack) subscribe(client *Client) iClientTrack { }) } + go func() { + clientCtx, cancel := context.WithCancel(client.context) + defer cancel() + <-clientCtx.Done() + ct.onTrackEnded() + }() + if t.remoteTrackMid != nil { t.remoteTrackMid.OnEnded(func() { ct.onTrackEnded() @@ -604,7 +625,7 @@ func (t *simulcastTrack) isTrackActive(quality QualityLevel) bool { delta := time.Since(time.Unix(0, t.lastReadHighTS.Load())) if delta > threshold { - glog.Warningf("track: remote track %s high is not active, last read was %d ms ago", t.Client().ID(), delta.Milliseconds()) + glog.Warningf("track: remote track %s high is not active, last read was %d ms ago", delta.Milliseconds()) return false } @@ -617,7 +638,7 @@ func (t *simulcastTrack) isTrackActive(quality QualityLevel) bool { delta := time.Since(time.Unix(0, t.lastReadMidTS.Load())) if delta > threshold { - glog.Warningf("track: remote track %s mid is not active, last read was %d ms ago", t.Client().ID(), delta.Milliseconds()) + glog.Warningf("track: remote track %s mid is not active, last read was %d ms ago", delta.Milliseconds()) return false } @@ -630,7 +651,7 @@ func (t *simulcastTrack) isTrackActive(quality QualityLevel) bool { delta := time.Since(time.Unix(0, t.lastReadLowTS.Load())) if delta > threshold { - glog.Warningf("track: remote track %s low is not active, last read was %d ms ago", t.Client().ID(), delta.Milliseconds()) + glog.Warningf("track: remote track %s low is not active, last read was %d ms ago", delta.Milliseconds()) return false } From 704fca1f94b647a98efd8962be8c1d8c18040f0c Mon Sep 17 00:00:00 2001 From: Yohan Totting Date: Tue, 21 Nov 2023 21:32:46 +0700 Subject: [PATCH 02/13] latest state --- client.go | 42 +-- client_test.go | 1 + clientracklist.go | 7 +- examples/http-websocket/main.go | 1 + main_test.go | 4 +- .../voiceactivedetector/interceptor.go | 79 +++--- relaytrack.go | 123 +++++++++ remotetrack.go | 5 +- scalableclienttrack.go | 2 +- sfu.go | 87 ++++--- simulcastclienttrack.go | 27 +- track.go | 239 +++++++++++++----- 12 files changed, 418 insertions(+), 199 deletions(-) create mode 100644 relaytrack.go diff --git a/client.go b/client.go index 987162d..721dfa5 100644 --- a/client.go +++ b/client.go @@ -147,6 +147,7 @@ type Client struct { ingressBandwidth *atomic.Uint32 ingressQualityLimitationReason *atomic.Value isDebug bool + vad *voiceactivedetector.Interceptor } func DefaultClientOptions() ClientOptions { @@ -160,7 +161,8 @@ func DefaultClientOptions() ClientOptions { func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Configuration, opts ClientOptions) *Client { var client *Client - var vadInterceptor *voiceactivedetector.Interceptor + + var vad *voiceactivedetector.Interceptor localCtx, cancel := context.WithCancel(s.context) m := &webrtc.MediaEngine{} @@ -198,7 +200,7 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi // enable voice detector vadInterceptorFactory.OnNew(func(i *voiceactivedetector.Interceptor) { - vadInterceptor = i + vad = i }) i.Add(vadInterceptorFactory) @@ -287,6 +289,7 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi ingressBandwidth: &atomic.Uint32{}, ingressQualityLimitationReason: &atomic.Value{}, onTracksAvailableCallbacks: make([]func([]ITrack), 0), + vad: vad, } // setup internal data channel @@ -306,6 +309,9 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi if s.enableBandwidthEstimator { go func() { estimator := <-estimatorChan + client.mu.Lock() + defer client.mu.Unlock() + client.estimator = estimator }() } @@ -321,14 +327,9 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi // to connected peers peerConnection.OnTrack(func(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { var track ITrack - var vad *voiceactivedetector.VoiceDetector defer glog.Info("client: new track ", remoteTrack.ID(), " Kind:", remoteTrack.Kind(), " Codec: ", remoteTrack.Codec().MimeType, " RID: ", remoteTrack.RID()) - if remoteTrack.Kind() == webrtc.RTPCodecTypeAudio && client.IsVADEnabled() { - vad = vadInterceptor.AddAudioTrack(remoteTrack) - } - onPLI := func() error { if client.peerConnection == nil || client.peerConnection.PC() == nil || client.peerConnection.PC().ConnectionState() != webrtc.PeerConnectionStateConnected { return nil @@ -344,13 +345,12 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi defer client.mu.Unlock() client.stats.SetReceiver(track.ID(), *stats) - glog.Info("client: stats updated ", track.ID(), " ", stats) } if remoteTrack.RID() == "" { // not simulcast - track = newTrack(client.context, client.id, remoteTrack, receiver, s.pliInterval, onPLI, vad, client.statsGetter, onStatsUpdated) + track = newTrack(client.context, client.id, remoteTrack, s.pliInterval, onPLI, client.statsGetter, onStatsUpdated) track.OnEnded(func() { client.stats.removeReceiverStats(remoteTrack.ID()) @@ -364,7 +364,7 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi track.SetAsProcessed() } else { // simulcast - var simulcast *simulcastTrack + var simulcast *SimulcastTrack var ok bool id := remoteTrack.ID() @@ -373,7 +373,7 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi if err != nil { // if track not found, add it - track = newSimulcastTrack(client.context, client.id, remoteTrack, receiver, s.pliInterval, onPLI, client.statsGetter, onStatsUpdated) + track = newSimulcastTrack(client.context, client.id, remoteTrack, s.pliInterval, onPLI, client.statsGetter, onStatsUpdated) if err := client.tracks.Add(track); err != nil { glog.Error("client: error add track ", err) } @@ -381,8 +381,8 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi track.OnEnded(func() { client.stats.removeReceiverStats(remoteTrack.ID()) }) - } else if simulcast, ok = track.(*simulcastTrack); ok { - simulcast.AddRemoteTrack(remoteTrack, receiver, client.statsGetter, onStatsUpdated) + } else if simulcast, ok = track.(*SimulcastTrack); ok { + simulcast.AddRemoteTrack(remoteTrack, client.statsGetter, onStatsUpdated) } // // only process track when the highest quality is available @@ -592,7 +592,6 @@ func (c *Client) renegotiateQueuOp() { // no need to run another negotiation if it's already in progress, it will rerun because we mark the negotiationneeded to true if c.isInRenegotiation.Load() { - glog.Info("sfu: renegotiation can't run, renegotiation still in progress ", c.ID) return } @@ -682,11 +681,11 @@ func (c *Client) setClientTrack(t ITrack) iClientTrack { } if t.IsSimulcast() { - simulcastTrack := t.(*simulcastTrack) + simulcastTrack := t.(*SimulcastTrack) outputTrack = simulcastTrack.subscribe(c) } else { - singleTrack := t.(*track) + singleTrack := t.(*Track) outputTrack = singleTrack.subscribe(c) } @@ -1015,6 +1014,17 @@ func (c *Client) SubscribeTracks(req []SubscribeTrackRequest) error { } } + + // look on relay tracks + for _, track := range c.SFU().relayTracks { + if track.ID() == r.TrackID { + if clientTrack := c.setClientTrack(track); clientTrack != nil { + clientTracks = append(clientTracks, clientTrack) + } + + trackFound = true + } + } } else if err != nil { return err } diff --git a/client_test.go b/client_test.go index 5da054f..d4ef9d0 100644 --- a/client_test.go +++ b/client_test.go @@ -192,6 +192,7 @@ Loop: select { case <-timeout.Done(): t.Fatal("timeout waiting for track added") + break Loop case <-trackChan: trackCount++ glog.Info("track added ", trackCount) diff --git a/clientracklist.go b/clientracklist.go index 720482b..d5ef37d 100644 --- a/clientracklist.go +++ b/clientracklist.go @@ -34,16 +34,13 @@ func (l *clientTrackList) Get(id string) iClientTrack { l.mu.Lock() defer l.mu.Unlock() - var track iClientTrack - for _, t := range l.tracks { if t.ID() == id { - track = t - break + return t } } - return track + return nil } func (l *clientTrackList) Length() int { diff --git a/examples/http-websocket/main.go b/examples/http-websocket/main.go index 2a78084..15a92a1 100644 --- a/examples/http-websocket/main.go +++ b/examples/http-websocket/main.go @@ -88,6 +88,7 @@ func main() { // create new room roomsOpts := sfu.DefaultRoomOptions() + roomsOpts.Bitrates.InitialBandwidth = 1_000_000 roomsOpts.PLIInterval = 3 * time.Second roomsOpts.Codecs = []string{webrtc.MimeTypeVP9, webrtc.MimeTypeH264, webrtc.MimeTypeOpus} defaultRoom, _ := roomManager.NewRoom(roomID, roomName, sfu.RoomTypeLocal, roomsOpts) diff --git a/main_test.go b/main_test.go index 9181ac9..657d5e8 100644 --- a/main_test.go +++ b/main_test.go @@ -11,8 +11,8 @@ import ( var roomManager *Manager func TestMain(m *testing.M) { - // flag.Set("logtostderr", "true") - // flag.Set("stderrthreshold", "INFO") + flag.Set("logtostderr", "true") + flag.Set("stderrthreshold", "INFO") flag.Parse() ctx, cancel := context.WithCancel(context.Background()) diff --git a/pkg/interceptors/voiceactivedetector/interceptor.go b/pkg/interceptors/voiceactivedetector/interceptor.go index 9b57230..866aab8 100644 --- a/pkg/interceptors/voiceactivedetector/interceptor.go +++ b/pkg/interceptors/voiceactivedetector/interceptor.go @@ -57,7 +57,7 @@ func DefaultConfig() Config { type Interceptor struct { context context.Context mu sync.Mutex - vads map[uint32]*VoiceDetector + vads map[string]*VoiceDetector config Config } @@ -66,7 +66,7 @@ func new(ctx context.Context) *Interceptor { context: ctx, mu: sync.Mutex{}, config: DefaultConfig(), - vads: make(map[uint32]*VoiceDetector), + vads: make(map[string]*VoiceDetector), } } @@ -77,22 +77,11 @@ func (v *Interceptor) SetConfig(config Config) { // BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method // will be called once per rtp packet. func (v *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { - return writer -} - -// UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track. -func (v *Interceptor) UnbindLocalStream(info *interceptor.StreamInfo) { - -} - -// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method -// will be called once per rtp packet. -func (v *Interceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { if info.MimeType != webrtc.MimeTypeOpus { - return reader + return writer } - vad := v.getVadBySSRC(info.SSRC) + vad := v.getVadByID(info.ID) if vad != nil { vad.updateStreamInfo(info) } @@ -101,30 +90,34 @@ func (v *Interceptor) BindRemoteStream(info *interceptor.StreamInfo, reader inte defer v.mu.Unlock() if vad == nil { - v.vads[info.SSRC] = newVAD(v.context, v, info) + v.vads[info.ID] = newVAD(v.context, v, info) } - return interceptor.RTPReaderFunc(func(bytes []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) { - n, a, err := reader.Read(bytes, attributes) - if err != nil { - p := rtp.Packet{} - if errUnmarshal := p.Unmarshal(bytes); errUnmarshal == nil { - _ = v.processPacket(info.SSRC, &p.Header) - } - } - - return n, a, err + return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + _ = v.processPacket(info.ID, header) + return writer.Write(header, payload, attributes) }) } -func (v *Interceptor) UnbindRemoteStream(info *interceptor.StreamInfo) { - vad := v.getVadBySSRC(info.SSRC) +// UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track. +func (v *Interceptor) UnbindLocalStream(info *interceptor.StreamInfo) { + vad := v.getVadByID(info.ID) if vad != nil { vad.Stop() } - delete(v.vads, info.SSRC) + delete(v.vads, info.ID) +} + +// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method +// will be called once per rtp packet. +func (v *Interceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { + return reader +} + +func (v *Interceptor) UnbindRemoteStream(info *interceptor.StreamInfo) { + } func (v *Interceptor) Close() error { @@ -144,11 +137,11 @@ func (v *Interceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor. return writer } -func (v *Interceptor) getVadBySSRC(ssrc uint32) *VoiceDetector { +func (v *Interceptor) getVadByID(id string) *VoiceDetector { v.mu.Lock() defer v.mu.Unlock() - vad, ok := v.vads[ssrc] + vad, ok := v.vads[id] if ok { return vad } @@ -156,15 +149,15 @@ func (v *Interceptor) getVadBySSRC(ssrc uint32) *VoiceDetector { return nil } -func (v *Interceptor) processPacket(ssrc uint32, header *rtp.Header) rtp.AudioLevelExtension { - audioData := v.getAudioLevel(ssrc, header) +func (v *Interceptor) processPacket(id string, header *rtp.Header) rtp.AudioLevelExtension { + audioData := v.getAudioLevel(id, header) if audioData.Level == 0 { return rtp.AudioLevelExtension{} } - vad := v.getVadBySSRC(ssrc) + vad := v.getVadByID(id) if vad == nil { - glog.Error("vad: not found vad for track ssrc", ssrc) + glog.Error("vad: not found vad for track id", id) return rtp.AudioLevelExtension{} } @@ -180,9 +173,9 @@ func (v *Interceptor) getConfig() Config { return v.config } -func (v *Interceptor) getAudioLevel(ssrc uint32, header *rtp.Header) rtp.AudioLevelExtension { +func (v *Interceptor) getAudioLevel(id string, header *rtp.Header) rtp.AudioLevelExtension { audioLevel := rtp.AudioLevelExtension{} - headerID := v.getAudioLevelExtensionID(ssrc) + headerID := v.getAudioLevelExtensionID(id) if headerID != 0 { ext := header.GetExtension(headerID) _ = audioLevel.Unmarshal(ext) @@ -197,11 +190,11 @@ func RegisterAudioLevelHeaderExtension(m *webrtc.MediaEngine) { } } -func (v *Interceptor) getAudioLevelExtensionID(ssrc uint32) uint8 { +func (v *Interceptor) getAudioLevelExtensionID(id string) uint8 { v.mu.Lock() defer v.mu.Unlock() - vad, ok := v.vads[ssrc] + vad, ok := v.vads[id] if ok { for _, extension := range vad.streamInfo.RTPHeaderExtensions { if extension.URI == sdp.AudioLevelURI { @@ -214,19 +207,17 @@ func (v *Interceptor) getAudioLevelExtensionID(ssrc uint32) uint8 { } // AddAudioTrack adds audio track to interceptor -func (v *Interceptor) AddAudioTrack(t *webrtc.TrackRemote) *VoiceDetector { +func (v *Interceptor) AddAudioTrack(t webrtc.TrackLocal) *VoiceDetector { if t.Kind() != webrtc.RTPCodecTypeAudio { glog.Error("vad: track is not audio track") return nil } - ssrc := uint32(t.SSRC()) - - vad := v.getVadBySSRC(ssrc) + vad := v.getVadByID(t.ID()) if vad == nil { v.mu.Lock() vad = newVAD(v.context, v, nil) - v.vads[ssrc] = vad + v.vads[t.ID()] = vad v.mu.Unlock() } diff --git a/relaytrack.go b/relaytrack.go new file mode 100644 index 0000000..b94ca82 --- /dev/null +++ b/relaytrack.go @@ -0,0 +1,123 @@ +package sfu + +import ( + "errors" + "sync" + "time" + + "github.com/pion/interceptor" + "github.com/pion/rtp" + "github.com/pion/webrtc/v3" +) + +type IRemoteTrack interface { + ID() string + RID() string + PayloadType() webrtc.PayloadType + Kind() webrtc.RTPCodecType + StreamID() string + SSRC() webrtc.SSRC + Msid() string + Codec() webrtc.RTPCodecParameters + Read(b []byte) (n int, attributes interceptor.Attributes, err error) + ReadRTP() (*rtp.Packet, interceptor.Attributes, error) + SetReadDeadline(deadline time.Time) error +} + +// TrackRemote represents a single inbound source of media +type RelayTrack struct { + mu sync.RWMutex + + id string + streamID string + + payloadType webrtc.PayloadType + kind webrtc.RTPCodecType + ssrc webrtc.SSRC + codec webrtc.RTPCodecParameters + rid string + rtpChan chan *rtp.Packet +} + +func NewTrackRelay(id, streamid, rid string, kind webrtc.RTPCodecType, ssrc webrtc.SSRC, mimeType string, rtpChan chan *rtp.Packet) IRemoteTrack { + return &RelayTrack{ + kind: kind, + ssrc: ssrc, + rid: rid, + } +} + +// ID is the unique identifier for this Track. This should be unique for the +// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video' +// and StreamID would be 'desktop' or 'webcam' +func (t *RelayTrack) ID() string { + t.mu.RLock() + defer t.mu.RUnlock() + return t.id +} + +// RID gets the RTP Stream ID of this Track +// With Simulcast you will have multiple tracks with the same ID, but different RID values. +// In many cases a TrackRemote will not have an RID, so it is important to assert it is non-zero +func (t *RelayTrack) RID() string { + t.mu.RLock() + defer t.mu.RUnlock() + + return t.rid +} + +// PayloadType gets the PayloadType of the track +func (t *RelayTrack) PayloadType() webrtc.PayloadType { + t.mu.RLock() + defer t.mu.RUnlock() + return t.payloadType +} + +// Kind gets the Kind of the track +func (t *RelayTrack) Kind() webrtc.RTPCodecType { + t.mu.RLock() + defer t.mu.RUnlock() + return t.kind +} + +// StreamID is the group this track belongs too. This must be unique +func (t *RelayTrack) StreamID() string { + t.mu.RLock() + defer t.mu.RUnlock() + return t.streamID +} + +// SSRC gets the SSRC of the track +func (t *RelayTrack) SSRC() webrtc.SSRC { + t.mu.RLock() + defer t.mu.RUnlock() + return t.ssrc +} + +// Msid gets the Msid of the track +func (t *RelayTrack) Msid() string { + return t.StreamID() + " " + t.ID() +} + +// Codec gets the Codec of the track +func (t *RelayTrack) Codec() webrtc.RTPCodecParameters { + t.mu.RLock() + defer t.mu.RUnlock() + return t.codec +} + +// Read reads data from the track. +func (t *RelayTrack) Read(b []byte) (n int, attributes interceptor.Attributes, err error) { + return 0, nil, errors.New("relaytrack: not implemented, use ReadRTP instead") +} + +// ReadRTP is a convenience method that wraps Read and unmarshals for you. +func (t *RelayTrack) ReadRTP() (*rtp.Packet, interceptor.Attributes, error) { + p := <-t.rtpChan + return p, nil, nil +} + +// SetReadDeadline sets the max amount of time the RTP stream will block before returning. 0 is forever. +func (t *RelayTrack) SetReadDeadline(deadline time.Time) error { + return errors.New("relaytrack: not implemented") +} diff --git a/remotetrack.go b/remotetrack.go index 8462aae..515188e 100644 --- a/remotetrack.go +++ b/remotetrack.go @@ -11,7 +11,6 @@ import ( "github.com/golang/glog" "github.com/pion/interceptor/pkg/stats" "github.com/pion/rtp" - "github.com/pion/webrtc/v3" ) type remoteTrack struct { @@ -19,7 +18,6 @@ type remoteTrack struct { cancel context.CancelFunc mu sync.Mutex track IRemoteTrack - receiver *webrtc.RTPReceiver onReadCallbacks []func(*rtp.Packet) onPLI func() error bitrate *atomic.Uint32 @@ -32,14 +30,13 @@ type remoteTrack struct { onStatsUpdated func(*stats.Stats) } -func newRemoteTrack(ctx context.Context, track IRemoteTrack, receiver *webrtc.RTPReceiver, pliInterval time.Duration, onPLI func() error, statsGetter stats.Getter, onStatsUpdated func(*stats.Stats)) *remoteTrack { +func newRemoteTrack(ctx context.Context, track IRemoteTrack, pliInterval time.Duration, onPLI func() error, statsGetter stats.Getter, onStatsUpdated func(*stats.Stats)) *remoteTrack { localctx, cancel := context.WithCancel(ctx) rt := &remoteTrack{ context: localctx, cancel: cancel, mu: sync.Mutex{}, track: track, - receiver: receiver, bitrate: &atomic.Uint32{}, previousBytesReceived: &atomic.Uint64{}, currentBytesReceived: &atomic.Uint64{}, diff --git a/scalableclienttrack.go b/scalableclienttrack.go index f4110b3..920bbb3 100644 --- a/scalableclienttrack.go +++ b/scalableclienttrack.go @@ -83,7 +83,7 @@ type scaleabletClientTrack struct { kind webrtc.RTPCodecType mimeType string localTrack *webrtc.TrackLocalStaticRTP - remoteTrack *track + remoteTrack *Track sequenceNumber uint16 lastQuality QualityLevel maxQuality QualityLevel diff --git a/sfu.go b/sfu.go index b955085..eaf5434 100644 --- a/sfu.go +++ b/sfu.go @@ -6,6 +6,7 @@ import ( "time" "github.com/golang/glog" + "github.com/pion/rtp" "github.com/pion/webrtc/v3" "golang.org/x/exp/slices" ) @@ -119,6 +120,7 @@ type SFU struct { onTrackAvailableCallbacks []func(tracks []ITrack) onClientRemovedCallbacks []func(*Client) onClientAddedCallbacks []func(*Client) + relayTracks map[string]ITrack } type PublishedTrack struct { @@ -226,6 +228,11 @@ func (s *SFU) NewClient(id, name string, opts ClientOptions) *Client { } } + // add relay tracks + for _, track := range s.relayTracks { + availableTracks = append(availableTracks, track) + } + if len(availableTracks) > 0 { client.onTracksAvailable(availableTracks) } @@ -298,6 +305,16 @@ func (s *SFU) NewClient(id, name string, opts ClientOptions) *Client { return client } +func (s *SFU) AvailableTracks() []ITrack { + tracks := make([]ITrack, 0) + + for _, client := range s.clients.GetClients() { + tracks = append(tracks, client.publishedTracks.GetTracks()...) + } + + return tracks +} + // Syncs track from connected client to other clients // returns true if need renegotiation func (s *SFU) syncTrack(client *Client) bool { @@ -583,48 +600,42 @@ func (s *SFU) OnTracksAvailable(callback func(tracks []ITrack)) { s.onTrackAvailableCallbacks = append(s.onTrackAvailableCallbacks, callback) } -// func (s *SFU) AddRelayTrack(id, streamid, rid string, kind webrtc.RTPCodecType, ssrc webrtc.SSRC, mimeType string, rtpChan chan *rtp.Packet) error { -// var track ITrack - -// relayTrack := NewTrackRelay(id, streamid, rid, kind, ssrc, mimeType, rtpChan) +func (s *SFU) AddRelayTrack(ctx context.Context, id, streamid, rid, clientid string, kind webrtc.RTPCodecType, ssrc webrtc.SSRC, mimeType string, rtpChan chan *rtp.Packet) error { + var track ITrack -// if rid == "" { -// // not simulcast -// track = newTrack(client, remoteTrack, receiver, vad) -// if err := client.tracks.Add(track); err != nil { -// glog.Error("client: error add track ", err) -// } + relayTrack := NewTrackRelay(id, streamid, rid, kind, ssrc, mimeType, rtpChan) -// client.onTrack(track) -// track.SetAsProcessed() -// } else { -// // simulcast -// var simulcast *simulcastTrack -// var ok bool - -// id := remoteTrack.ID() + onPLI := func() error { + return nil + } -// track, err = client.tracks.Get(id) // not found because the track is not added yet due to race condition + if rid == "" { + // not simulcast + track = newTrack(ctx, clientid, relayTrack, s.pliInterval, onPLI, nil, nil) + s.mu.Lock() + s.relayTracks[relayTrack.ID()] = track + s.mu.Unlock() + } else { + // simulcast + var simulcast *SimulcastTrack + var ok bool -// if err != nil { -// // if track not found, add it -// track = newSimulcastTrack(client, remoteTrack, receiver) -// if err := client.tracks.Add(track); err != nil { -// glog.Error("client: error add track ", err) -// } -// } else if simulcast, ok = track.(*simulcastTrack); ok { -// simulcast.AddRemoteTrack(remoteTrack, receiver) -// } + s.mu.Lock() + track, ok := s.relayTracks[relayTrack.ID()] + if !ok { + // if track not found, add it + track = newSimulcastTrack(ctx, clientid, relayTrack, s.pliInterval, onPLI, nil, nil) + s.relayTracks[relayTrack.ID()] = track + + } else if simulcast, ok = track.(*SimulcastTrack); ok { + simulcast.AddRemoteTrack(relayTrack, nil, nil) + } + s.mu.Unlock() + } -// // // only process track when the highest quality is available -// // simulcast.mu.Lock() -// // isHighAvailable := simulcast.remoteTrackHigh != nil -// // simulcast.mu.Unlock() + s.broadcastTracksToAutoSubscribeClients(clientid, []ITrack{track}) -// if !track.IsProcessed() { -// client.onTrack(track) -// track.SetAsProcessed() -// } + s.onTracksAvailable([]ITrack{track}) -// } -// } + return nil +} diff --git a/simulcastclienttrack.go b/simulcastclienttrack.go index 26c4be9..84e4cbd 100644 --- a/simulcastclienttrack.go +++ b/simulcastclienttrack.go @@ -17,7 +17,7 @@ type simulcastClientTrack struct { kind webrtc.RTPCodecType mimeType string localTrack *webrtc.TrackLocalStaticRTP - remoteTrack *simulcastTrack + remoteTrack *SimulcastTrack lastBlankSequenceNumber *atomic.Uint32 sequenceNumber *atomic.Uint32 lastQuality *atomic.Uint32 @@ -40,11 +40,8 @@ func (t *simulcastClientTrack) isFirstKeyframePacket(p *rtp.Packet) bool { return isKeyframe && t.lastTimestamp.Load() != p.Timestamp } -func (t *simulcastClientTrack) send(p *rtp.Packet, quality QualityLevel, lastQuality QualityLevel, isPaddingPackets bool) { - if !isPaddingPackets { - // set the last processed packet timestamp to identify if is begining of the new frame - t.lastTimestamp.Store(p.Timestamp) - } +func (t *simulcastClientTrack) send(p *rtp.Packet, quality QualityLevel, lastQuality QualityLevel) { + t.lastTimestamp.Store(p.Timestamp) if lastQuality != quality { t.lastQuality.Store(uint32(quality)) @@ -130,33 +127,23 @@ func (t *simulcastClientTrack) push(p *rtp.Packet, quality QualityLevel) { } if trackQuality == quality { - t.send(p, trackQuality, lastQuality, false) + t.send(p, trackQuality, lastQuality) } else if trackQuality == QualityNone && quality == QualityLow { if isFirstKeyframePacket { glog.Warning("clienttrack: no quality level to send") if t.localTrack.Codec().MimeType == webrtc.MimeTypeH264 { // if codec is h264, send a blank frame once p.Payload = getH264BlankFrame() - t.send(p, QualityLow, lastQuality, false) + t.send(p, QualityLow, lastQuality) } else if t.localTrack.Codec().MimeType != webrtc.MimeTypeH264 && t.remoteTrack.isTrackActive(QualityLow) { // if codec is not h264, send a low quality packet - t.send(p, QualityLow, lastQuality, false) + t.send(p, QualityLow, lastQuality) } else { // last effort, send the last quality - t.send(p, lastQuality, lastQuality, false) + t.send(p, lastQuality, lastQuality) } } } - // } else if Uint32ToQualityLevel(t.paddingQuality.Load()) == quality { - // paddingTS := t.paddingTS.Load() - // if p.Timestamp > paddingTS { - // // new frame, reset padding quality - // t.paddingQuality.Store(QualityNone) - // } else if p.Timestamp == paddingTS { - // // padding packet - // t.send(p, quality, quality, true) - // } - // } } diff --git a/track.go b/track.go index efa90ae..10fd077 100644 --- a/track.go +++ b/track.go @@ -58,18 +58,18 @@ type ITrack interface { TotalTracks() int OnEnded(func()) onEnded() + Relay(func(webrtc.SSRC, *rtp.Packet)) } -type track struct { +type Track struct { mu sync.Mutex base baseTrack remoteTrack *remoteTrack onEndedCallbacks []func() - vad *voiceactivedetector.VoiceDetector onReadCallbacks []func(*rtp.Packet, QualityLevel) } -func newTrack(ctx context.Context, clientID string, trackRemote *webrtc.TrackRemote, receiver *webrtc.RTPReceiver, pliInterval time.Duration, onPLI func() error, vad *voiceactivedetector.VoiceDetector, stats stats.Getter, onStatsUpdated func(*stats.Stats)) ITrack { +func newTrack(ctx context.Context, clientID string, trackRemote IRemoteTrack, pliInterval time.Duration, onPLI func() error, stats stats.Getter, onStatsUpdated func(*stats.Stats)) ITrack { ctList := newClientTrackList() baseTrack := baseTrack{ @@ -83,18 +83,17 @@ func newTrack(ctx context.Context, clientID string, trackRemote *webrtc.TrackRem clientTracks: ctList, } - t := &track{ + t := &Track{ mu: sync.Mutex{}, base: baseTrack, - vad: vad, onReadCallbacks: make([]func(*rtp.Packet, QualityLevel), 0), onEndedCallbacks: make([]func(), 0), } - t.remoteTrack = newRemoteTrack(ctx, trackRemote, receiver, pliInterval, onPLI, stats, onStatsUpdated) + t.remoteTrack = newRemoteTrack(ctx, trackRemote, pliInterval, onPLI, stats, onStatsUpdated) t.remoteTrack.OnRead(func(p *rtp.Packet) { // do - tracks := ctList.GetTracks() + tracks := t.base.clientTracks.GetTracks() for _, track := range tracks { track.push(p, QualityHigh) // quality doesn't matter on non simulcast track } @@ -109,11 +108,11 @@ func newTrack(ctx context.Context, clientID string, trackRemote *webrtc.TrackRem return t } -func (t *track) ClientID() string { +func (t *Track) ClientID() string { return t.base.clientid } -func (t *track) createLocalTrack() *webrtc.TrackLocalStaticRTP { +func (t *Track) createLocalTrack() *webrtc.TrackLocalStaticRTP { track, newTrackErr := webrtc.NewTrackLocalStaticRTP(t.remoteTrack.track.Codec().RTPCodecCapability, t.base.id, t.base.streamid) if newTrackErr != nil { panic(newTrackErr) @@ -122,63 +121,71 @@ func (t *track) createLocalTrack() *webrtc.TrackLocalStaticRTP { return track } -func (t *track) ID() string { +func (t *Track) ID() string { return t.base.id } -func (t *track) StreamID() string { +func (t *Track) StreamID() string { return t.base.streamid } -func (t *track) RemoteTrack() *remoteTrack { +func (t *Track) SSRC() webrtc.SSRC { + return t.remoteTrack.track.SSRC() +} + +func (t *Track) RemoteTrack() *remoteTrack { t.mu.Lock() defer t.mu.Unlock() return t.remoteTrack } -func (t *track) IsScreen() bool { +func (t *Track) IsScreen() bool { return t.base.isScreen.Load() } -func (t *track) IsSimulcast() bool { +func (t *Track) IsSimulcast() bool { return false } -func (t *track) IsScaleable() bool { +func (t *Track) IsScaleable() bool { return t.MimeType() == webrtc.MimeTypeVP9 } -func (t *track) IsProcessed() bool { +func (t *Track) IsProcessed() bool { t.mu.Lock() defer t.mu.Unlock() return t.base.isProcessed } -func (t *track) Kind() webrtc.RTPCodecType { +func (t *Track) Kind() webrtc.RTPCodecType { return t.base.kind } -func (t *track) MimeType() string { +func (t *Track) MimeType() string { return t.base.codec.MimeType } -func (t *track) TotalTracks() int { +func (t *Track) SSRCHigh() webrtc.SSRC { + return t.remoteTrack.Track().SSRC() +} + +func (t *Track) SSRCMid() webrtc.SSRC { + return t.remoteTrack.Track().SSRC() +} + +func (t *Track) SSRCLow() webrtc.SSRC { + return t.remoteTrack.Track().SSRC() +} + +func (t *Track) TotalTracks() int { return 1 } -func (t *track) subscribe(c *Client) iClientTrack { +func (t *Track) subscribe(c *Client) iClientTrack { var ct iClientTrack - if t.Kind() == webrtc.RTPCodecTypeAudio && c.IsVADEnabled() { - glog.Info("track: voice activity detector enabled") - t.vad.OnVoiceDetected(func(activity voiceactivedetector.VoiceActivity) { - // send through datachannel - c.onVoiceDetected(activity) - }) - } - if t.MimeType() == webrtc.MimeTypeVP9 { ct = &scaleabletClientTrack{ mu: sync.RWMutex{}, @@ -211,6 +218,15 @@ func (t *track) subscribe(c *Client) iClientTrack { } } + if t.Kind() == webrtc.RTPCodecTypeAudio && c.IsVADEnabled() { + glog.Info("track: voice activity detector enabled") + vad := c.vad.AddAudioTrack(ct.LocalTrack()) + vad.OnVoiceDetected(func(activity voiceactivedetector.VoiceActivity) { + // send through datachannel + c.onVoiceDetected(activity) + }) + } + t.remoteTrack.OnEnded(func() { ct.onTrackEnded() }) @@ -227,18 +243,18 @@ func (t *track) subscribe(c *Client) iClientTrack { return ct } -func (t *track) SetSourceType(sourceType TrackType) { +func (t *Track) SetSourceType(sourceType TrackType) { t.base.isScreen.Store(sourceType == TrackTypeScreen) } -func (t *track) SetAsProcessed() { +func (t *Track) SetAsProcessed() { t.mu.Lock() defer t.mu.Unlock() t.base.isProcessed = true } -func (t *track) onEnded() { +func (t *Track) onEnded() { t.mu.Lock() defer t.mu.Unlock() @@ -247,21 +263,21 @@ func (t *track) onEnded() { } } -func (t *track) OnEnded(callback func()) { +func (t *Track) OnEnded(callback func()) { t.mu.Lock() defer t.mu.Unlock() t.onEndedCallbacks = append(t.onEndedCallbacks, callback) } -func (t *track) OnRead(callback func(*rtp.Packet, QualityLevel)) { +func (t *Track) OnRead(callback func(*rtp.Packet, QualityLevel)) { t.mu.Lock() defer t.mu.Unlock() t.onReadCallbacks = append(t.onReadCallbacks, callback) } -func (t *track) onRead(p *rtp.Packet, quality QualityLevel) { +func (t *Track) onRead(p *rtp.Packet, quality QualityLevel) { t.mu.Lock() defer t.mu.Unlock() @@ -270,12 +286,18 @@ func (t *track) onRead(p *rtp.Packet, quality QualityLevel) { } } -type simulcastTrack struct { +func (t *Track) Relay(f func(webrtc.SSRC, *rtp.Packet)) { + t.OnRead(func(p *rtp.Packet, quality QualityLevel) { + f(t.SSRC(), p) + }) +} + +type SimulcastTrack struct { context context.Context mu sync.Mutex base *baseTrack baseTS uint32 - onTrackComplete func() + onTrackCompleteCallbacks []func() remoteTrackHigh *remoteTrack remoteTrackHighBaseTS uint32 highSequence uint16 @@ -301,8 +323,8 @@ type simulcastTrack struct { onPLI func() error } -func newSimulcastTrack(ctx context.Context, clientid string, track IRemoteTrack, receiver *webrtc.RTPReceiver, pliInterval time.Duration, onPLI func() error, stats stats.Getter, onStatsUpdated func(*stats.Stats)) ITrack { - t := &simulcastTrack{ +func newSimulcastTrack(ctx context.Context, clientid string, track IRemoteTrack, pliInterval time.Duration, onPLI func() error, stats stats.Getter, onStatsUpdated func(*stats.Stats)) ITrack { + t := &SimulcastTrack{ context: ctx, mu: sync.Mutex{}, base: &baseTrack{ @@ -321,6 +343,7 @@ func newSimulcastTrack(ctx context.Context, clientid string, track IRemoteTrack, lastHighKeyframeTS: &atomic.Int64{}, lastMidKeyframeTS: &atomic.Int64{}, lastLowKeyframeTS: &atomic.Int64{}, + onTrackCompleteCallbacks: make([]func(), 0), onAddedRemoteTrackCallbacks: make([]func(*remoteTrack), 0), onEndedCallbacks: make([]func(), 0), onReadCallbacks: make([]func(*rtp.Packet, QualityLevel), 0), @@ -328,23 +351,23 @@ func newSimulcastTrack(ctx context.Context, clientid string, track IRemoteTrack, onPLI: onPLI, } - _ = t.AddRemoteTrack(track, receiver, stats, onStatsUpdated) + _ = t.AddRemoteTrack(track, stats, onStatsUpdated) return t } -func (t *simulcastTrack) ClientID() string { +func (t *SimulcastTrack) ClientID() string { return t.base.clientid } -func (t *simulcastTrack) onRemoteTrackAdded(f func(*remoteTrack)) { +func (t *SimulcastTrack) onRemoteTrackAdded(f func(*remoteTrack)) { t.mu.Lock() defer t.mu.Unlock() t.onAddedRemoteTrackCallbacks = append(t.onAddedRemoteTrackCallbacks, f) } -func (t *simulcastTrack) onRemoteTrackAddedCallbacks(track *remoteTrack) { +func (t *SimulcastTrack) onRemoteTrackAddedCallbacks(track *remoteTrack) { t.mu.Lock() defer t.mu.Unlock() @@ -353,42 +376,51 @@ func (t *simulcastTrack) onRemoteTrackAddedCallbacks(track *remoteTrack) { } } -func (t *simulcastTrack) OnTrackComplete(f func()) { +func (t *SimulcastTrack) OnTrackComplete(f func()) { + t.mu.Lock() + defer t.mu.Unlock() + + t.onTrackCompleteCallbacks = append(t.onTrackCompleteCallbacks, f) +} + +func (t *SimulcastTrack) onTrackComplete() { t.mu.Lock() defer t.mu.Unlock() - t.onTrackComplete = f + for _, f := range t.onTrackCompleteCallbacks { + f() + } } // TODO: this is contain multiple tracks, there is a possibility remote track high is not available yet -func (t *simulcastTrack) ID() string { +func (t *SimulcastTrack) ID() string { return t.base.id } -func (t *simulcastTrack) StreamID() string { +func (t *SimulcastTrack) StreamID() string { return t.base.streamid } -func (t *simulcastTrack) IsSimulcast() bool { +func (t *SimulcastTrack) IsSimulcast() bool { return true } -func (t *simulcastTrack) IsScaleable() bool { +func (t *SimulcastTrack) IsScaleable() bool { return false } -func (t *simulcastTrack) IsProcessed() bool { +func (t *SimulcastTrack) IsProcessed() bool { t.mu.Lock() defer t.mu.Unlock() return t.base.isProcessed } -func (t *simulcastTrack) Kind() webrtc.RTPCodecType { +func (t *SimulcastTrack) Kind() webrtc.RTPCodecType { return t.base.kind } -func (t *simulcastTrack) AddRemoteTrack(track IRemoteTrack, receiver *webrtc.RTPReceiver, stats stats.Getter, onStatsUpdated func(*stats.Stats)) *remoteTrack { +func (t *SimulcastTrack) AddRemoteTrack(track IRemoteTrack, stats stats.Getter, onStatsUpdated func(*stats.Stats)) *remoteTrack { var remoteTrack *remoteTrack quality := RIDToQuality(track.RID()) @@ -434,12 +466,14 @@ func (t *simulcastTrack) AddRemoteTrack(track IRemoteTrack, receiver *webrtc.RTP t.mu.Lock() - remoteTrack = newRemoteTrack(t.context, track, receiver, t.pliInterval, t.onPLI, stats, onStatsUpdated) + remoteTrack = newRemoteTrack(t.context, track, t.pliInterval, t.onPLI, stats, onStatsUpdated) + remoteTrack.OnRead(onRead) + switch quality { case QualityHigh: - t.remoteTrackHigh = remoteTrack + remoteTrack.OnEnded(func() { t.mu.Lock() defer t.mu.Unlock() @@ -452,6 +486,7 @@ func (t *simulcastTrack) AddRemoteTrack(track IRemoteTrack, receiver *webrtc.RTP case QualityMid: t.remoteTrackMid = remoteTrack + remoteTrack.OnEnded(func() { t.mu.Lock() defer t.mu.Unlock() @@ -461,6 +496,7 @@ func (t *simulcastTrack) AddRemoteTrack(track IRemoteTrack, receiver *webrtc.RTP t.onEnded() } }) + case QualityLow: t.remoteTrackLow = remoteTrack remoteTrack.OnEnded(func() { @@ -478,7 +514,7 @@ func (t *simulcastTrack) AddRemoteTrack(track IRemoteTrack, receiver *webrtc.RTP } // check if all simulcast tracks are available - if t.onTrackComplete != nil && t.remoteTrackHigh != nil && t.remoteTrackMid != nil && t.remoteTrackLow != nil { + if t.remoteTrackHigh != nil && t.remoteTrackMid != nil && t.remoteTrackLow != nil { t.onTrackComplete() } @@ -489,7 +525,7 @@ func (t *simulcastTrack) AddRemoteTrack(track IRemoteTrack, receiver *webrtc.RTP return remoteTrack } -func (t *simulcastTrack) getRemoteTrack(q QualityLevel) *remoteTrack { +func (t *SimulcastTrack) getRemoteTrack(q QualityLevel) *remoteTrack { t.mu.Lock() defer t.mu.Unlock() @@ -505,7 +541,7 @@ func (t *simulcastTrack) getRemoteTrack(q QualityLevel) *remoteTrack { return nil } -func (t *simulcastTrack) subscribe(client *Client) iClientTrack { +func (t *SimulcastTrack) subscribe(client *Client) iClientTrack { // Create a local track, all our SFU clients will be fed via this track track, newTrackErr := webrtc.NewTrackLocalStaticRTP(t.base.codec.RTPCodecCapability, t.base.id, t.base.streamid) if newTrackErr != nil { @@ -572,22 +608,26 @@ func (t *simulcastTrack) subscribe(client *Client) iClientTrack { return ct } -func (t *simulcastTrack) SetSourceType(sourceType TrackType) { +func (t *SimulcastTrack) SetSourceType(sourceType TrackType) { t.base.isScreen.Store(sourceType == TrackTypeScreen) } -func (t *simulcastTrack) SetAsProcessed() { +func (t *SimulcastTrack) SetAsProcessed() { t.mu.Lock() defer t.mu.Unlock() t.base.isProcessed = true } -func (t *simulcastTrack) IsScreen() bool { +func (t *SimulcastTrack) IsScreen() bool { return t.base.isScreen.Load() } -func (t *simulcastTrack) TotalTracks() int { +func (t *SimulcastTrack) IsTrackComplete() bool { + return t.TotalTracks() == 3 +} + +func (t *SimulcastTrack) TotalTracks() int { t.mu.Lock() defer t.mu.Unlock() @@ -608,7 +648,7 @@ func (t *simulcastTrack) TotalTracks() int { } // track is considered active if the track is not nil and the latest read operation was 500ms ago -func (t *simulcastTrack) isTrackActive(quality QualityLevel) bool { +func (t *SimulcastTrack) isTrackActive(quality QualityLevel) bool { t.mu.Lock() defer t.mu.Unlock() @@ -661,7 +701,7 @@ func (t *simulcastTrack) isTrackActive(quality QualityLevel) bool { return false } -func (t *simulcastTrack) sendPLI(quality QualityLevel) { +func (t *SimulcastTrack) sendPLI(quality QualityLevel) { switch quality { case QualityHigh: if t.remoteTrackHigh != nil { @@ -684,18 +724,18 @@ func (t *simulcastTrack) sendPLI(quality QualityLevel) { } } -func (t *simulcastTrack) MimeType() string { +func (t *SimulcastTrack) MimeType() string { return t.base.codec.MimeType } -func (t *simulcastTrack) OnEnded(f func()) { +func (t *SimulcastTrack) OnEnded(f func()) { t.mu.Lock() defer t.mu.Unlock() t.onEndedCallbacks = append(t.onEndedCallbacks, f) } -func (t *simulcastTrack) onEnded() { +func (t *SimulcastTrack) onEnded() { t.mu.Lock() defer t.mu.Unlock() @@ -704,14 +744,14 @@ func (t *simulcastTrack) onEnded() { } } -func (t *simulcastTrack) OnRead(callback func(*rtp.Packet, QualityLevel)) { +func (t *SimulcastTrack) OnRead(callback func(*rtp.Packet, QualityLevel)) { t.mu.Lock() defer t.mu.Unlock() t.onReadCallbacks = append(t.onReadCallbacks, callback) } -func (t *simulcastTrack) onRead(p *rtp.Packet, quality QualityLevel) { +func (t *SimulcastTrack) onRead(p *rtp.Packet, quality QualityLevel) { t.mu.Lock() defer t.mu.Unlock() @@ -720,6 +760,67 @@ func (t *simulcastTrack) onRead(p *rtp.Packet, quality QualityLevel) { } } +func (t *SimulcastTrack) SSRCHigh() webrtc.SSRC { + if t.remoteTrackHigh == nil { + return 0 + } + + return t.remoteTrackHigh.Track().SSRC() +} + +func (t *SimulcastTrack) SSRCMid() webrtc.SSRC { + if t.remoteTrackMid == nil { + return 0 + } + + return t.remoteTrackMid.Track().SSRC() +} + +func (t *SimulcastTrack) SSRCLow() webrtc.SSRC { + if t.remoteTrackLow == nil { + return 0 + } + + return t.remoteTrackLow.Track().SSRC() +} + +func (t *SimulcastTrack) RIDHigh() string { + if t.remoteTrackHigh == nil { + return "" + } + + return t.remoteTrackHigh.track.RID() +} + +func (t *SimulcastTrack) RIDMid() string { + if t.remoteTrackMid == nil { + return "" + } + + return t.remoteTrackMid.track.RID() +} + +func (t *SimulcastTrack) RIDLow() string { + if t.remoteTrackLow == nil { + return "" + } + + return t.remoteTrackLow.track.RID() +} + +func (t *SimulcastTrack) Relay(f func(webrtc.SSRC, *rtp.Packet)) { + t.OnRead(func(p *rtp.Packet, quality QualityLevel) { + switch quality { + case QualityHigh: + f(t.SSRCHigh(), p) + case QualityMid: + f(t.SSRCMid(), p) + case QualityLow: + f(t.SSRCLow(), p) + } + }) +} + type SubscribeTrackRequest struct { ClientID string `json:"client_id"` StreamID string `json:"stream_id"` @@ -758,8 +859,8 @@ func (t *trackList) Add(track ITrack) error { } func (t *trackList) Get(ID string) (ITrack, error) { - // t.mu.Lock() - // defer t.mu.Unlock() + t.mu.Lock() + defer t.mu.Unlock() if track, ok := t.tracks[ID]; ok { return track, nil From 7de1bd6eea50b0d1d30f0cd2cfb61add27340437 Mon Sep 17 00:00:00 2001 From: Yohan Totting Date: Wed, 22 Nov 2023 08:36:59 +0700 Subject: [PATCH 03/13] fix typo scaleabletrack --- scalableclienttrack.go | 46 +++++++++++++++++++++--------------------- track.go | 2 +- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/scalableclienttrack.go b/scalableclienttrack.go index 920bbb3..6a958b3 100644 --- a/scalableclienttrack.go +++ b/scalableclienttrack.go @@ -76,7 +76,7 @@ func DefaultQualityPreset() QualityPreset { } } -type scaleabletClientTrack struct { +type scaleableClientTrack struct { id string mu sync.RWMutex client *Client @@ -99,14 +99,14 @@ type scaleabletClientTrack struct { qualityPreset QualityPreset } -func (t *scaleabletClientTrack) Client() *Client { +func (t *scaleableClientTrack) Client() *Client { t.mu.Lock() defer t.mu.Unlock() return t.client } -func (t *scaleabletClientTrack) writeRTP(p *rtp.Packet) { +func (t *scaleableClientTrack) writeRTP(p *rtp.Packet) { t.lastTimestamp = p.Timestamp t.sequenceNumber = p.SequenceNumber @@ -115,7 +115,7 @@ func (t *scaleabletClientTrack) writeRTP(p *rtp.Packet) { } } -func (t *scaleabletClientTrack) isKeyframe(vp9 *codecs.VP9Packet) bool { +func (t *scaleableClientTrack) isKeyframe(vp9 *codecs.VP9Packet) bool { if len(vp9.Payload) < 1 { return false } @@ -136,7 +136,7 @@ func (t *scaleabletClientTrack) isKeyframe(vp9 *codecs.VP9Packet) bool { // this where the temporal and spatial layers are will be decided to be sent to the client or not // compare it with the claimed quality to decide if the packet should be sent or not -func (t *scaleabletClientTrack) push(p *rtp.Packet, _ QualityLevel) { +func (t *scaleableClientTrack) push(p *rtp.Packet, _ QualityLevel) { var qualityPreset IQualityPreset vp9Packet := &codecs.VP9Packet{} @@ -211,16 +211,16 @@ func (t *scaleabletClientTrack) push(p *rtp.Packet, _ QualityLevel) { t.send(p) } -func (t *scaleabletClientTrack) send(p *rtp.Packet) { +func (t *scaleableClientTrack) send(p *rtp.Packet) { p.SequenceNumber = p.SequenceNumber - t.dropCounter t.writeRTP(p) } -func (t *scaleabletClientTrack) RemoteTrack() *remoteTrack { +func (t *scaleableClientTrack) RemoteTrack() *remoteTrack { return t.remoteTrack.remoteTrack } -func (t *scaleabletClientTrack) getCurrentBitrate() uint32 { +func (t *scaleableClientTrack) getCurrentBitrate() uint32 { currentTrack := t.RemoteTrack() if currentTrack == nil { return 0 @@ -229,47 +229,47 @@ func (t *scaleabletClientTrack) getCurrentBitrate() uint32 { return currentTrack.GetCurrentBitrate() } -func (t *scaleabletClientTrack) ID() string { +func (t *scaleableClientTrack) ID() string { return t.id } -func (t *scaleabletClientTrack) Kind() webrtc.RTPCodecType { +func (t *scaleableClientTrack) Kind() webrtc.RTPCodecType { return t.kind } -func (t *scaleabletClientTrack) LocalTrack() *webrtc.TrackLocalStaticRTP { +func (t *scaleableClientTrack) LocalTrack() *webrtc.TrackLocalStaticRTP { return t.localTrack } -func (t *scaleabletClientTrack) IsScreen() bool { +func (t *scaleableClientTrack) IsScreen() bool { return t.isScreen } -func (t *scaleabletClientTrack) SetSourceType(sourceType TrackType) { +func (t *scaleableClientTrack) SetSourceType(sourceType TrackType) { t.isScreen = (sourceType == TrackTypeScreen) } -func (t *scaleabletClientTrack) SetLastQuality(quality QualityLevel) { +func (t *scaleableClientTrack) SetLastQuality(quality QualityLevel) { t.mu.Lock() defer t.mu.Unlock() t.lastQuality = quality } -func (t *scaleabletClientTrack) LastQuality() QualityLevel { +func (t *scaleableClientTrack) LastQuality() QualityLevel { t.mu.Lock() defer t.mu.Unlock() return QualityLevel(t.lastQuality) } -func (t *scaleabletClientTrack) OnTrackEnded(callback func()) { +func (t *scaleableClientTrack) OnTrackEnded(callback func()) { t.mu.Lock() defer t.mu.Unlock() t.onTrackEndedCallbacks = append(t.onTrackEndedCallbacks, callback) } -func (t *scaleabletClientTrack) onTrackEnded() { +func (t *scaleableClientTrack) onTrackEnded() { if t.isEnded { return } @@ -281,33 +281,33 @@ func (t *scaleabletClientTrack) onTrackEnded() { t.isEnded = true } -func (t *scaleabletClientTrack) SetMaxQuality(quality QualityLevel) { +func (t *scaleableClientTrack) SetMaxQuality(quality QualityLevel) { t.mu.Lock() defer t.mu.Unlock() t.maxQuality = quality } -func (t *scaleabletClientTrack) MaxQuality() QualityLevel { +func (t *scaleableClientTrack) MaxQuality() QualityLevel { t.mu.Lock() defer t.mu.Unlock() return t.maxQuality } -func (t *scaleabletClientTrack) IsSimulcast() bool { +func (t *scaleableClientTrack) IsSimulcast() bool { return false } -func (t *scaleabletClientTrack) IsScaleable() bool { +func (t *scaleableClientTrack) IsScaleable() bool { return true } -func (t *scaleabletClientTrack) RequestPLI() { +func (t *scaleableClientTrack) RequestPLI() { t.remoteTrack.remoteTrack.sendPLI() } -func (t *scaleabletClientTrack) getQuality() QualityLevel { +func (t *scaleableClientTrack) getQuality() QualityLevel { claim := t.client.bitrateController.GetClaim(t.ID()) if claim == nil { diff --git a/track.go b/track.go index 10fd077..ea92a0d 100644 --- a/track.go +++ b/track.go @@ -187,7 +187,7 @@ func (t *Track) subscribe(c *Client) iClientTrack { var ct iClientTrack if t.MimeType() == webrtc.MimeTypeVP9 { - ct = &scaleabletClientTrack{ + ct = &scaleableClientTrack{ mu: sync.RWMutex{}, id: t.base.id, kind: t.base.kind, From fbad034c79f35cc52e2ceebff3ed591a7c70aef7 Mon Sep 17 00:00:00 2001 From: Yohan Totting Date: Wed, 22 Nov 2023 09:28:39 +0700 Subject: [PATCH 04/13] remove callback read on remotetrack --- bitratecontroller.go | 2 +- remotetrack.go | 26 +++++--------------------- track.go | 44 ++++++++++++++++++++++---------------------- 3 files changed, 28 insertions(+), 44 deletions(-) diff --git a/bitratecontroller.go b/bitratecontroller.go index c8fe588..1e8cfe1 100644 --- a/bitratecontroller.go +++ b/bitratecontroller.go @@ -298,7 +298,7 @@ func (bc *bitrateController) addClaims(clientTracks []iClientTrack) error { if clientTrack.IsSimulcast() { clientTrack.(*simulcastClientTrack).lastQuality.Store(uint32(trackQuality)) } else if clientTrack.IsScaleable() { - clientTrack.(*scaleabletClientTrack).lastQuality = trackQuality + clientTrack.(*scaleableClientTrack).lastQuality = trackQuality } _, err := bc.addClaim(clientTrack, trackQuality, true) diff --git a/remotetrack.go b/remotetrack.go index 515188e..54f6d7e 100644 --- a/remotetrack.go +++ b/remotetrack.go @@ -18,7 +18,7 @@ type remoteTrack struct { cancel context.CancelFunc mu sync.Mutex track IRemoteTrack - onReadCallbacks []func(*rtp.Packet) + onRead func(*rtp.Packet) onPLI func() error bitrate *atomic.Uint32 previousBytesReceived *atomic.Uint64 @@ -30,7 +30,7 @@ type remoteTrack struct { onStatsUpdated func(*stats.Stats) } -func newRemoteTrack(ctx context.Context, track IRemoteTrack, pliInterval time.Duration, onPLI func() error, statsGetter stats.Getter, onStatsUpdated func(*stats.Stats)) *remoteTrack { +func newRemoteTrack(ctx context.Context, track IRemoteTrack, pliInterval time.Duration, onPLI func() error, statsGetter stats.Getter, onStatsUpdated func(*stats.Stats), onRead func(*rtp.Packet)) *remoteTrack { localctx, cancel := context.WithCancel(ctx) rt := &remoteTrack{ context: localctx, @@ -45,7 +45,7 @@ func newRemoteTrack(ctx context.Context, track IRemoteTrack, pliInterval time.Du statsGetter: statsGetter, onStatsUpdated: onStatsUpdated, onPLI: onPLI, - onReadCallbacks: make([]func(*rtp.Packet), 0), + onRead: onRead, } rt.enableIntervalPLI(pliInterval) @@ -55,24 +55,6 @@ func newRemoteTrack(ctx context.Context, track IRemoteTrack, pliInterval time.Du return rt } -func (t *remoteTrack) OnRead(f func(*rtp.Packet)) { - t.mu.Lock() - defer t.mu.Unlock() - - t.onReadCallbacks = append(t.onReadCallbacks, f) -} - -func (t *remoteTrack) onRead(p *rtp.Packet) { - t.mu.Lock() - defer t.mu.Unlock() - - for _, f := range t.onReadCallbacks { - f(p) - } - - go t.updateStats() -} - func (t *remoteTrack) OnEnded(f func()) { t.mu.Lock() defer t.mu.Unlock() @@ -107,6 +89,8 @@ func (t *remoteTrack) readRTP() { } t.onRead(rtp) + + go t.updateStats() } } }() diff --git a/track.go b/track.go index ea92a0d..4dd002c 100644 --- a/track.go +++ b/track.go @@ -51,14 +51,14 @@ type ITrack interface { IsProcessed() bool SetSourceType(TrackType) SetAsProcessed() - OnRead(func(*rtp.Packet, QualityLevel)) + OnRead(func(rtp.Packet, QualityLevel)) IsScreen() bool Kind() webrtc.RTPCodecType MimeType() string TotalTracks() int OnEnded(func()) onEnded() - Relay(func(webrtc.SSRC, *rtp.Packet)) + Relay(func(webrtc.SSRC, rtp.Packet)) } type Track struct { @@ -66,7 +66,7 @@ type Track struct { base baseTrack remoteTrack *remoteTrack onEndedCallbacks []func() - onReadCallbacks []func(*rtp.Packet, QualityLevel) + onReadCallbacks []func(rtp.Packet, QualityLevel) } func newTrack(ctx context.Context, clientID string, trackRemote IRemoteTrack, pliInterval time.Duration, onPLI func() error, stats stats.Getter, onStatsUpdated func(*stats.Stats)) ITrack { @@ -86,12 +86,11 @@ func newTrack(ctx context.Context, clientID string, trackRemote IRemoteTrack, pl t := &Track{ mu: sync.Mutex{}, base: baseTrack, - onReadCallbacks: make([]func(*rtp.Packet, QualityLevel), 0), + onReadCallbacks: make([]func(rtp.Packet, QualityLevel), 0), onEndedCallbacks: make([]func(), 0), } - t.remoteTrack = newRemoteTrack(ctx, trackRemote, pliInterval, onPLI, stats, onStatsUpdated) - t.remoteTrack.OnRead(func(p *rtp.Packet) { + onRead := func(p *rtp.Packet) { // do tracks := t.base.clientTracks.GetTracks() for _, track := range tracks { @@ -99,7 +98,9 @@ func newTrack(ctx context.Context, clientID string, trackRemote IRemoteTrack, pl } t.onRead(p, QualityHigh) - }) + } + + t.remoteTrack = newRemoteTrack(ctx, trackRemote, pliInterval, onPLI, stats, onStatsUpdated, onRead) t.remoteTrack.OnEnded(func() { t.onEnded() @@ -270,7 +271,7 @@ func (t *Track) OnEnded(callback func()) { t.onEndedCallbacks = append(t.onEndedCallbacks, callback) } -func (t *Track) OnRead(callback func(*rtp.Packet, QualityLevel)) { +func (t *Track) OnRead(callback func(rtp.Packet, QualityLevel)) { t.mu.Lock() defer t.mu.Unlock() @@ -278,16 +279,17 @@ func (t *Track) OnRead(callback func(*rtp.Packet, QualityLevel)) { } func (t *Track) onRead(p *rtp.Packet, quality QualityLevel) { - t.mu.Lock() - defer t.mu.Unlock() + // t.mu.Lock() + // defer t.mu.Unlock() for _, callback := range t.onReadCallbacks { - callback(p, quality) + pClone := *p + callback(pClone, quality) } } -func (t *Track) Relay(f func(webrtc.SSRC, *rtp.Packet)) { - t.OnRead(func(p *rtp.Packet, quality QualityLevel) { +func (t *Track) Relay(f func(webrtc.SSRC, rtp.Packet)) { + t.OnRead(func(p rtp.Packet, quality QualityLevel) { f(t.SSRC(), p) }) } @@ -318,7 +320,7 @@ type SimulcastTrack struct { lastLowKeyframeTS *atomic.Int64 onAddedRemoteTrackCallbacks []func(*remoteTrack) onEndedCallbacks []func() - onReadCallbacks []func(*rtp.Packet, QualityLevel) + onReadCallbacks []func(rtp.Packet, QualityLevel) pliInterval time.Duration onPLI func() error } @@ -346,7 +348,7 @@ func newSimulcastTrack(ctx context.Context, clientid string, track IRemoteTrack, onTrackCompleteCallbacks: make([]func(), 0), onAddedRemoteTrackCallbacks: make([]func(*remoteTrack), 0), onEndedCallbacks: make([]func(), 0), - onReadCallbacks: make([]func(*rtp.Packet, QualityLevel), 0), + onReadCallbacks: make([]func(rtp.Packet, QualityLevel), 0), pliInterval: pliInterval, onPLI: onPLI, } @@ -466,9 +468,7 @@ func (t *SimulcastTrack) AddRemoteTrack(track IRemoteTrack, stats stats.Getter, t.mu.Lock() - remoteTrack = newRemoteTrack(t.context, track, t.pliInterval, t.onPLI, stats, onStatsUpdated) - - remoteTrack.OnRead(onRead) + remoteTrack = newRemoteTrack(t.context, track, t.pliInterval, t.onPLI, stats, onStatsUpdated, onRead) switch quality { case QualityHigh: @@ -744,7 +744,7 @@ func (t *SimulcastTrack) onEnded() { } } -func (t *SimulcastTrack) OnRead(callback func(*rtp.Packet, QualityLevel)) { +func (t *SimulcastTrack) OnRead(callback func(rtp.Packet, QualityLevel)) { t.mu.Lock() defer t.mu.Unlock() @@ -756,7 +756,7 @@ func (t *SimulcastTrack) onRead(p *rtp.Packet, quality QualityLevel) { defer t.mu.Unlock() for _, callback := range t.onReadCallbacks { - callback(p, quality) + callback(*p, quality) } } @@ -808,8 +808,8 @@ func (t *SimulcastTrack) RIDLow() string { return t.remoteTrackLow.track.RID() } -func (t *SimulcastTrack) Relay(f func(webrtc.SSRC, *rtp.Packet)) { - t.OnRead(func(p *rtp.Packet, quality QualityLevel) { +func (t *SimulcastTrack) Relay(f func(webrtc.SSRC, rtp.Packet)) { + t.OnRead(func(p rtp.Packet, quality QualityLevel) { switch quality { case QualityHigh: f(t.SSRCHigh(), p) From 28f600b3af510e483814e3afd0bcc76c0fc96ede Mon Sep 17 00:00:00 2001 From: Yohan Totting Date: Wed, 22 Nov 2023 10:12:13 +0700 Subject: [PATCH 05/13] pass by value for packet --- client.go | 4 ++-- clienttrack.go | 6 +++--- docs/extension.md | 5 +++++ remotetrack.go | 6 +++--- scalableclienttrack.go | 8 ++++---- simulcastclienttrack.go | 16 +++++++++------- track.go | 33 ++++++++++++++++----------------- util.go | 4 ++-- 8 files changed, 44 insertions(+), 38 deletions(-) create mode 100644 docs/extension.md diff --git a/client.go b/client.go index 721dfa5..e2609f9 100644 --- a/client.go +++ b/client.go @@ -385,9 +385,9 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi simulcast.AddRemoteTrack(remoteTrack, client.statsGetter, onStatsUpdated) } - // // only process track when the highest quality is available + // // only process track when the lowest quality is available // simulcast.mu.Lock() - // isHighAvailable := simulcast.remoteTrackHigh != nil + // isLowAvailable := simulcast.remoteTrackLow != nil // simulcast.mu.Unlock() if !track.IsProcessed() { diff --git a/clienttrack.go b/clienttrack.go index 07abb75..5e7744f 100644 --- a/clienttrack.go +++ b/clienttrack.go @@ -10,7 +10,7 @@ import ( ) type iClientTrack interface { - push(rtp *rtp.Packet, quality QualityLevel) + push(rtp rtp.Packet, quality QualityLevel) ID() string Kind() webrtc.RTPCodecType LocalTrack() *webrtc.TrackLocalStaticRTP @@ -50,7 +50,7 @@ func (t *clientTrack) Kind() webrtc.RTPCodecType { return t.remoteTrack.track.Kind() } -func (t *clientTrack) push(rtp *rtp.Packet, quality QualityLevel) { +func (t *clientTrack) push(rtp rtp.Packet, quality QualityLevel) { if t.client.peerConnection.PC().ConnectionState() != webrtc.PeerConnectionStateConnected { return } @@ -59,7 +59,7 @@ func (t *clientTrack) push(rtp *rtp.Packet, quality QualityLevel) { // do something here with audio level } - if err := t.localTrack.WriteRTP(rtp); err != nil { + if err := t.localTrack.WriteRTP(&rtp); err != nil { glog.Error("clienttrack: error on write rtp", err) } } diff --git a/docs/extension.md b/docs/extension.md new file mode 100644 index 0000000..f27c6cc --- /dev/null +++ b/docs/extension.md @@ -0,0 +1,5 @@ +# Developing an extension package +An extension can add more features to the SFU without need to modify the SFU code. + +## How it works +The extension can be develop by utilizing the events from SFU components. Each component has its own events. \ No newline at end of file diff --git a/remotetrack.go b/remotetrack.go index 54f6d7e..846d3c2 100644 --- a/remotetrack.go +++ b/remotetrack.go @@ -18,7 +18,7 @@ type remoteTrack struct { cancel context.CancelFunc mu sync.Mutex track IRemoteTrack - onRead func(*rtp.Packet) + onRead func(rtp.Packet) onPLI func() error bitrate *atomic.Uint32 previousBytesReceived *atomic.Uint64 @@ -30,7 +30,7 @@ type remoteTrack struct { onStatsUpdated func(*stats.Stats) } -func newRemoteTrack(ctx context.Context, track IRemoteTrack, pliInterval time.Duration, onPLI func() error, statsGetter stats.Getter, onStatsUpdated func(*stats.Stats), onRead func(*rtp.Packet)) *remoteTrack { +func newRemoteTrack(ctx context.Context, track IRemoteTrack, pliInterval time.Duration, onPLI func() error, statsGetter stats.Getter, onStatsUpdated func(*stats.Stats), onRead func(rtp.Packet)) *remoteTrack { localctx, cancel := context.WithCancel(ctx) rt := &remoteTrack{ context: localctx, @@ -88,7 +88,7 @@ func (t *remoteTrack) readRTP() { return } - t.onRead(rtp) + t.onRead(*rtp) go t.updateStats() } diff --git a/scalableclienttrack.go b/scalableclienttrack.go index 6a958b3..d754da8 100644 --- a/scalableclienttrack.go +++ b/scalableclienttrack.go @@ -106,11 +106,11 @@ func (t *scaleableClientTrack) Client() *Client { return t.client } -func (t *scaleableClientTrack) writeRTP(p *rtp.Packet) { +func (t *scaleableClientTrack) writeRTP(p rtp.Packet) { t.lastTimestamp = p.Timestamp t.sequenceNumber = p.SequenceNumber - if err := t.localTrack.WriteRTP(p); err != nil { + if err := t.localTrack.WriteRTP(&p); err != nil { glog.Error("track: error on write rtp", err) } } @@ -136,7 +136,7 @@ func (t *scaleableClientTrack) isKeyframe(vp9 *codecs.VP9Packet) bool { // this where the temporal and spatial layers are will be decided to be sent to the client or not // compare it with the claimed quality to decide if the packet should be sent or not -func (t *scaleableClientTrack) push(p *rtp.Packet, _ QualityLevel) { +func (t *scaleableClientTrack) push(p rtp.Packet, _ QualityLevel) { var qualityPreset IQualityPreset vp9Packet := &codecs.VP9Packet{} @@ -211,7 +211,7 @@ func (t *scaleableClientTrack) push(p *rtp.Packet, _ QualityLevel) { t.send(p) } -func (t *scaleableClientTrack) send(p *rtp.Packet) { +func (t *scaleableClientTrack) send(p rtp.Packet) { p.SequenceNumber = p.SequenceNumber - t.dropCounter t.writeRTP(p) } diff --git a/simulcastclienttrack.go b/simulcastclienttrack.go index 84e4cbd..ac4ed94 100644 --- a/simulcastclienttrack.go +++ b/simulcastclienttrack.go @@ -34,32 +34,32 @@ func (t *simulcastClientTrack) Client() *Client { return t.client } -func (t *simulcastClientTrack) isFirstKeyframePacket(p *rtp.Packet) bool { +func (t *simulcastClientTrack) isFirstKeyframePacket(p rtp.Packet) bool { isKeyframe := IsKeyframe(t.mimeType, p) return isKeyframe && t.lastTimestamp.Load() != p.Timestamp } -func (t *simulcastClientTrack) send(p *rtp.Packet, quality QualityLevel, lastQuality QualityLevel) { +func (t *simulcastClientTrack) send(p rtp.Packet, quality QualityLevel, lastQuality QualityLevel) { t.lastTimestamp.Store(p.Timestamp) if lastQuality != quality { t.lastQuality.Store(uint32(quality)) } - t.rewritePacket(p, quality) + p = t.rewritePacket(p, quality) t.writeRTP(p) } -func (t *simulcastClientTrack) writeRTP(p *rtp.Packet) { - if err := t.localTrack.WriteRTP(p); err != nil { +func (t *simulcastClientTrack) writeRTP(p rtp.Packet) { + if err := t.localTrack.WriteRTP(&p); err != nil { glog.Error("track: error on write rtp", err) } } -func (t *simulcastClientTrack) push(p *rtp.Packet, quality QualityLevel) { +func (t *simulcastClientTrack) push(p rtp.Packet, quality QualityLevel) { var trackQuality QualityLevel lastQuality := t.LastQuality() @@ -245,7 +245,7 @@ func (t *simulcastClientTrack) IsScaleable() bool { return false } -func (t *simulcastClientTrack) rewritePacket(p *rtp.Packet, quality QualityLevel) { +func (t *simulcastClientTrack) rewritePacket(p rtp.Packet, quality QualityLevel) rtp.Packet { t.remoteTrack.mu.Lock() defer t.remoteTrack.mu.Unlock() // make sure the timestamp and sequence number is consistent from the previous packet even it is not the same track @@ -265,6 +265,8 @@ func (t *simulcastClientTrack) rewritePacket(p *rtp.Packet, quality QualityLevel t.sequenceNumber.Add(uint32(sequenceDelta)) p.SequenceNumber = uint16(t.sequenceNumber.Load()) + + return p } func (t *simulcastClientTrack) RequestPLI() { diff --git a/track.go b/track.go index 4dd002c..044ca30 100644 --- a/track.go +++ b/track.go @@ -90,7 +90,7 @@ func newTrack(ctx context.Context, clientID string, trackRemote IRemoteTrack, pl onEndedCallbacks: make([]func(), 0), } - onRead := func(p *rtp.Packet) { + onRead := func(p rtp.Packet) { // do tracks := t.base.clientTracks.GetTracks() for _, track := range tracks { @@ -278,13 +278,12 @@ func (t *Track) OnRead(callback func(rtp.Packet, QualityLevel)) { t.onReadCallbacks = append(t.onReadCallbacks, callback) } -func (t *Track) onRead(p *rtp.Packet, quality QualityLevel) { +func (t *Track) onRead(p rtp.Packet, quality QualityLevel) { // t.mu.Lock() // defer t.mu.Unlock() for _, callback := range t.onReadCallbacks { - pClone := *p - callback(pClone, quality) + callback(p, quality) } } @@ -427,7 +426,7 @@ func (t *SimulcastTrack) AddRemoteTrack(track IRemoteTrack, stats stats.Getter, quality := RIDToQuality(track.RID()) - onRead := func(p *rtp.Packet) { + onRead := func(p rtp.Packet) { // set the base timestamp for the track if it is not set yet if t.baseTS == 0 { t.baseTS = p.Timestamp @@ -466,18 +465,18 @@ func (t *SimulcastTrack) AddRemoteTrack(track IRemoteTrack, stats stats.Getter, t.onRead(p, quality) } - t.mu.Lock() - remoteTrack = newRemoteTrack(t.context, track, t.pliInterval, t.onPLI, stats, onStatsUpdated, onRead) switch quality { case QualityHigh: + t.mu.Lock() t.remoteTrackHigh = remoteTrack + t.mu.Unlock() remoteTrack.OnEnded(func() { t.mu.Lock() - defer t.mu.Unlock() t.remoteTrackHigh = nil + t.mu.Unlock() if t.remoteTrackHigh == nil && t.remoteTrackMid == nil && t.remoteTrackLow == nil { t.onEnded() @@ -485,12 +484,14 @@ func (t *SimulcastTrack) AddRemoteTrack(track IRemoteTrack, stats stats.Getter, }) case QualityMid: + t.mu.Lock() t.remoteTrackMid = remoteTrack + t.mu.Unlock() remoteTrack.OnEnded(func() { t.mu.Lock() - defer t.mu.Unlock() t.remoteTrackMid = nil + t.mu.Unlock() if t.remoteTrackHigh == nil && t.remoteTrackMid == nil && t.remoteTrackLow == nil { t.onEnded() @@ -498,11 +499,14 @@ func (t *SimulcastTrack) AddRemoteTrack(track IRemoteTrack, stats stats.Getter, }) case QualityLow: + t.mu.Lock() t.remoteTrackLow = remoteTrack + t.mu.Unlock() + remoteTrack.OnEnded(func() { t.mu.Lock() - defer t.mu.Unlock() t.remoteTrackLow = nil + t.mu.Unlock() if t.remoteTrackHigh == nil && t.remoteTrackMid == nil && t.remoteTrackLow == nil { t.onEnded() @@ -518,8 +522,6 @@ func (t *SimulcastTrack) AddRemoteTrack(track IRemoteTrack, stats stats.Getter, t.onTrackComplete() } - t.mu.Unlock() - t.onRemoteTrackAddedCallbacks(remoteTrack) return remoteTrack @@ -751,12 +753,9 @@ func (t *SimulcastTrack) OnRead(callback func(rtp.Packet, QualityLevel)) { t.onReadCallbacks = append(t.onReadCallbacks, callback) } -func (t *SimulcastTrack) onRead(p *rtp.Packet, quality QualityLevel) { - t.mu.Lock() - defer t.mu.Unlock() - +func (t *SimulcastTrack) onRead(p rtp.Packet, quality QualityLevel) { for _, callback := range t.onReadCallbacks { - callback(*p, quality) + callback(p, quality) } } diff --git a/util.go b/util.go index d54b41d..6be3e20 100644 --- a/util.go +++ b/util.go @@ -107,7 +107,7 @@ func RegisterSimulcastHeaderExtensions(m *webrtc.MediaEngine, codecType webrtc.R } } -func IsKeyframe(codec string, packet *rtp.Packet) bool { +func IsKeyframe(codec string, packet rtp.Packet) bool { isIt1, isIt2 := Keyframe(codec, packet) return isIt1 && isIt2 } @@ -119,7 +119,7 @@ func IsKeyframe(codec string, packet *rtp.Packet) bool { // It returns (true, true) if that is the case, (false, true) if that is // definitely not the case, and (false, false) if the information cannot // be determined. -func Keyframe(codec string, packet *rtp.Packet) (bool, bool) { +func Keyframe(codec string, packet rtp.Packet) (bool, bool) { if strings.EqualFold(codec, "video/vp8") { var vp8 codecs.VP8Packet _, err := vp8.Unmarshal(packet.Payload) From 52297113d8fc3cb32130402d92e913d6582ec7f7 Mon Sep 17 00:00:00 2001 From: Yohan Totting Date: Wed, 22 Nov 2023 10:48:30 +0700 Subject: [PATCH 06/13] fix vad blocking --- examples/http-websocket/main.go | 2 +- pkg/interceptors/voiceactivedetector/vad.go | 26 ++++++++++++++------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/examples/http-websocket/main.go b/examples/http-websocket/main.go index 15a92a1..2711179 100644 --- a/examples/http-websocket/main.go +++ b/examples/http-websocket/main.go @@ -66,7 +66,7 @@ func main() { sfuOpts := sfu.DefaultOptions() sfuOpts.EnableMux = false - sfuOpts.EnableBandwidthEstimator = false + sfuOpts.EnableBandwidthEstimator = true _, turnEnabled := os.LookupEnv("TURN_ENABLED") if turnEnabled { diff --git a/pkg/interceptors/voiceactivedetector/vad.go b/pkg/interceptors/voiceactivedetector/vad.go index a743423..ad00a83 100644 --- a/pkg/interceptors/voiceactivedetector/vad.go +++ b/pkg/interceptors/voiceactivedetector/vad.go @@ -3,6 +3,7 @@ package voiceactivedetector import ( "context" "sync" + "time" "github.com/pion/interceptor" "github.com/pion/rtp" @@ -63,9 +64,15 @@ func newVAD(ctx context.Context, i *Interceptor, streamInfo *interceptor.StreamI // once the tail margin close, stop send the packet. func (v *VoiceDetector) run() { go func() { + ticker := time.NewTicker(500 * time.Millisecond) ctx, cancel := context.WithCancel(v.context) v.cancel = cancel - defer cancel() + + defer func() { + ticker.Stop() + cancel() + }() + for { select { case <-ctx.Done(): @@ -73,24 +80,21 @@ func (v *VoiceDetector) run() { case voicePacket := <-v.channel: if v.isDetected(voicePacket) { // send all packets to callback - v.sendPacketsToCallback() - } else { - // drop packets in queue if pass the head margin - v.dropExpiredPackets() - + go v.sendPacketsToCallback() } + case <-ticker.C: + go v.dropExpiredPackets() } } }() } func (v *VoiceDetector) dropExpiredPackets() { - v.mu.Lock() - defer v.mu.Unlock() - loop: for { + v.mu.Lock() if len(v.VoicePackets) == 0 { + v.mu.Unlock() break loop } @@ -101,8 +105,10 @@ loop: // drop packet v.VoicePackets = v.VoicePackets[1:] } else { + v.mu.Unlock() break loop } + v.mu.Unlock() } } @@ -125,7 +131,9 @@ func (v *VoiceDetector) sendPacketsToCallback() { }) // clear packets + v.mu.Lock() v.VoicePackets = make([]VoicePacketData, 0) + v.mu.Unlock() } func (v *VoiceDetector) getPackets() []VoicePacketData { From c57a037300ff41686581adfc38ea9ee2fa34541e Mon Sep 17 00:00:00 2001 From: Yohan Totting Date: Wed, 22 Nov 2023 15:13:05 +0700 Subject: [PATCH 07/13] change onEnded with context --- bitratecontroller.go | 11 +- client.go | 23 ++- clientracklist.go | 12 +- clienttrack.go | 43 ++---- pkg/interceptors/voiceactivedetector/vad.go | 34 ++--- remotetrack.go | 15 +- scalableclienttrack.go | 7 + simulcastclienttrack.go | 7 + track.go | 153 +++++++------------- 9 files changed, 138 insertions(+), 167 deletions(-) diff --git a/bitratecontroller.go b/bitratecontroller.go index 1e8cfe1..3099904 100644 --- a/bitratecontroller.go +++ b/bitratecontroller.go @@ -332,10 +332,16 @@ func (bc *bitrateController) addClaim(clientTrack iClientTrack, quality QualityL bitrate: bitrate, } - clientTrack.OnTrackEnded(func() { + go func() { + ctx, cancel := context.WithCancel(clientTrack.Context()) + defer cancel() + <-ctx.Done() bc.removeClaim(clientTrack.ID()) + if bc.client.IsDebugEnabled() { + glog.Info("clienttrack: track ", clientTrack.ID(), " claim removed") + } clientTrack.Client().stats.removeSenderStats(clientTrack.ID()) - }) + }() return bc.claims[clientTrack.ID()], nil } @@ -345,6 +351,7 @@ func (bc *bitrateController) removeClaim(id string) { defer bc.mu.Unlock() if _, ok := bc.claims[id]; !ok { + glog.Error("bitrate: track ", id, " is not exists") return } diff --git a/client.go b/client.go index e2609f9..d522378 100644 --- a/client.go +++ b/client.go @@ -352,9 +352,12 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi track = newTrack(client.context, client.id, remoteTrack, s.pliInterval, onPLI, client.statsGetter, onStatsUpdated) - track.OnEnded(func() { + go func() { + ctx, cancel := context.WithCancel(track.Context()) + defer cancel() + <-ctx.Done() client.stats.removeReceiverStats(remoteTrack.ID()) - }) + }() if err := client.tracks.Add(track); err != nil { glog.Error("client: error add track ", err) @@ -378,9 +381,12 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi glog.Error("client: error add track ", err) } - track.OnEnded(func() { + go func() { + ctx, cancel := context.WithCancel(track.Context()) + defer cancel() + <-ctx.Done() client.stats.removeReceiverStats(remoteTrack.ID()) - }) + }() } else if simulcast, ok = track.(*SimulcastTrack); ok { simulcast.AddRemoteTrack(remoteTrack, client.statsGetter, onStatsUpdated) } @@ -697,7 +703,12 @@ func (c *Client) setClientTrack(t ITrack) iClientTrack { return nil } - t.OnEnded(func() { + go func() { + ctx, cancel := context.WithCancel(outputTrack.Context()) + defer cancel() + + <-ctx.Done() + if c == nil { return } @@ -720,7 +731,7 @@ func (c *Client) setClientTrack(t ITrack) iClientTrack { } c.renegotiate() - }) + }() // enable RTCP report and stats c.enableReportAndStats(transc.Sender(), outputTrack) diff --git a/clientracklist.go b/clientracklist.go index d5ef37d..bea6357 100644 --- a/clientracklist.go +++ b/clientracklist.go @@ -1,6 +1,9 @@ package sfu -import "sync" +import ( + "context" + "sync" +) type clientTrackList struct { mu sync.RWMutex @@ -11,9 +14,12 @@ func (l *clientTrackList) Add(track iClientTrack) { l.mu.RLock() defer l.mu.RUnlock() - track.OnTrackEnded(func() { + go func() { + ctx, cancel := context.WithCancel(track.Context()) + defer cancel() + <-ctx.Done() l.remove(track.ID()) - }) + }() l.tracks = append(l.tracks, track) } diff --git a/clienttrack.go b/clienttrack.go index 5e7744f..468cc9c 100644 --- a/clienttrack.go +++ b/clienttrack.go @@ -1,6 +1,7 @@ package sfu import ( + "context" "sync" "sync/atomic" @@ -12,14 +13,13 @@ import ( type iClientTrack interface { push(rtp rtp.Packet, quality QualityLevel) ID() string + Context() context.Context Kind() webrtc.RTPCodecType LocalTrack() *webrtc.TrackLocalStaticRTP IsScreen() bool IsSimulcast() bool IsScaleable() bool SetSourceType(TrackType) - OnTrackEnded(func()) - onTrackEnded() Client() *Client RequestPLI() SetMaxQuality(quality QualityLevel) @@ -27,21 +27,26 @@ type iClientTrack interface { } type clientTrack struct { - id string - mu sync.RWMutex - client *Client - kind webrtc.RTPCodecType - mimeType string - localTrack *webrtc.TrackLocalStaticRTP - remoteTrack *remoteTrack - isScreen *atomic.Bool - onTrackEndedCallbacks []func() + id string + context context.Context + cancel context.CancelFunc + mu sync.RWMutex + client *Client + kind webrtc.RTPCodecType + mimeType string + localTrack *webrtc.TrackLocalStaticRTP + remoteTrack *remoteTrack + isScreen *atomic.Bool } func (t *clientTrack) ID() string { return t.id } +func (t *clientTrack) Context() context.Context { + return t.context +} + func (t *clientTrack) Client() *Client { return t.client } @@ -76,22 +81,6 @@ func (t *clientTrack) SetSourceType(sourceType TrackType) { t.isScreen.Store(sourceType == TrackTypeScreen) } -func (t *clientTrack) OnTrackEnded(callback func()) { - t.mu.Lock() - defer t.mu.Unlock() - - t.onTrackEndedCallbacks = append(t.onTrackEndedCallbacks, callback) -} - -func (t *clientTrack) onTrackEnded() { - t.mu.Lock() - defer t.mu.Unlock() - - for _, callback := range t.onTrackEndedCallbacks { - callback() - } -} - func (t *clientTrack) IsSimulcast() bool { return false } diff --git a/pkg/interceptors/voiceactivedetector/vad.go b/pkg/interceptors/voiceactivedetector/vad.go index ad00a83..84ea194 100644 --- a/pkg/interceptors/voiceactivedetector/vad.go +++ b/pkg/interceptors/voiceactivedetector/vad.go @@ -34,7 +34,7 @@ type VoiceDetector struct { startDetected uint32 lastDetectedTS uint32 channel chan VoicePacketData - mu sync.Mutex + mu sync.RWMutex VoicePackets []VoicePacketData callbacks []func(activity VoiceActivity) } @@ -45,7 +45,7 @@ func newVAD(ctx context.Context, i *Interceptor, streamInfo *interceptor.StreamI interceptor: i, streamInfo: streamInfo, channel: make(chan VoicePacketData), - mu: sync.Mutex{}, + mu: sync.RWMutex{}, VoicePackets: make([]VoicePacketData, 0), callbacks: make([]func(VoiceActivity), 0), } @@ -80,7 +80,7 @@ func (v *VoiceDetector) run() { case voicePacket := <-v.channel: if v.isDetected(voicePacket) { // send all packets to callback - go v.sendPacketsToCallback() + v.sendPacketsToCallback() } case <-ticker.C: go v.dropExpiredPackets() @@ -131,14 +131,14 @@ func (v *VoiceDetector) sendPacketsToCallback() { }) // clear packets - v.mu.Lock() + v.mu.RLock() v.VoicePackets = make([]VoicePacketData, 0) - v.mu.Unlock() + v.mu.RUnlock() } func (v *VoiceDetector) getPackets() []VoicePacketData { - v.mu.Lock() - defer v.mu.Unlock() + v.mu.RLock() + defer v.mu.RUnlock() packets := make([]VoicePacketData, 0) packets = append(packets, v.VoicePackets...) @@ -152,13 +152,15 @@ func (v *VoiceDetector) onVoiceDetected(activity VoiceActivity) { } func (v *VoiceDetector) OnVoiceDetected(callback func(VoiceActivity)) { - v.mu.Lock() - defer v.mu.Unlock() + v.mu.RLock() + defer v.mu.RUnlock() v.callbacks = append(v.callbacks, callback) } func (v *VoiceDetector) isDetected(vp VoicePacketData) bool { + v.mu.RLock() v.VoicePackets = append(v.VoicePackets, vp) + v.mu.RUnlock() clockRate := v.streamInfo.ClockRate @@ -215,9 +217,6 @@ func (v *VoiceDetector) isDetected(vp VoicePacketData) bool { } func (v *VoiceDetector) addPacket(header *rtp.Header, audioLevel uint8) { - v.mu.Lock() - defer v.mu.Unlock() - v.channel <- VoicePacketData{ SequenceNo: header.SequenceNumber, Timestamp: header.Timestamp, @@ -226,23 +225,20 @@ func (v *VoiceDetector) addPacket(header *rtp.Header, audioLevel uint8) { } func (v *VoiceDetector) UpdateTrack(trackID, streamID string) { - v.mu.Lock() - defer v.mu.Unlock() + v.mu.RLock() + defer v.mu.RUnlock() v.trackID = trackID v.streamID = streamID } func (v *VoiceDetector) Stop() { - v.mu.Lock() - defer v.mu.Unlock() - v.cancel() } func (v *VoiceDetector) updateStreamInfo(streamInfo *interceptor.StreamInfo) { - v.mu.Lock() - defer v.mu.Unlock() + v.mu.RLock() + defer v.mu.RUnlock() v.streamInfo = streamInfo if streamInfo.ID != "" { diff --git a/remotetrack.go b/remotetrack.go index 846d3c2..6df8df9 100644 --- a/remotetrack.go +++ b/remotetrack.go @@ -55,17 +55,8 @@ func newRemoteTrack(ctx context.Context, track IRemoteTrack, pliInterval time.Du return rt } -func (t *remoteTrack) OnEnded(f func()) { - t.mu.Lock() - defer t.mu.Unlock() - - t.onEndedCallbacks = append(t.onEndedCallbacks, f) -} - -func (t *remoteTrack) onEnded() { - for _, f := range t.onEndedCallbacks { - f() - } +func (t *remoteTrack) Context() context.Context { + return t.context } func (t *remoteTrack) readRTP() { @@ -80,8 +71,6 @@ func (t *remoteTrack) readRTP() { default: rtp, _, readErr := t.track.ReadRTP() if readErr == io.EOF { - t.onEnded() - return } else if readErr != nil { glog.Error("error reading rtp: ", readErr.Error()) diff --git a/scalableclienttrack.go b/scalableclienttrack.go index d754da8..c35ab6b 100644 --- a/scalableclienttrack.go +++ b/scalableclienttrack.go @@ -1,6 +1,7 @@ package sfu import ( + "context" "sync" "github.com/golang/glog" @@ -78,6 +79,8 @@ func DefaultQualityPreset() QualityPreset { type scaleableClientTrack struct { id string + context context.Context + cancel context.CancelFunc mu sync.RWMutex client *Client kind webrtc.RTPCodecType @@ -106,6 +109,10 @@ func (t *scaleableClientTrack) Client() *Client { return t.client } +func (t *scaleableClientTrack) Context() context.Context { + return t.context +} + func (t *scaleableClientTrack) writeRTP(p rtp.Packet) { t.lastTimestamp = p.Timestamp t.sequenceNumber = p.SequenceNumber diff --git a/simulcastclienttrack.go b/simulcastclienttrack.go index ac4ed94..ad07178 100644 --- a/simulcastclienttrack.go +++ b/simulcastclienttrack.go @@ -1,6 +1,7 @@ package sfu import ( + "context" "sync" "sync/atomic" "time" @@ -14,6 +15,8 @@ type simulcastClientTrack struct { id string mu sync.RWMutex client *Client + context context.Context + cancel context.CancelFunc kind webrtc.RTPCodecType mimeType string localTrack *webrtc.TrackLocalStaticRTP @@ -34,6 +37,10 @@ func (t *simulcastClientTrack) Client() *Client { return t.client } +func (t *simulcastClientTrack) Context() context.Context { + return t.context +} + func (t *simulcastClientTrack) isFirstKeyframePacket(p rtp.Packet) bool { isKeyframe := IsKeyframe(t.mimeType, p) diff --git a/track.go b/track.go index 044ca30..21eb3bd 100644 --- a/track.go +++ b/track.go @@ -56,12 +56,13 @@ type ITrack interface { Kind() webrtc.RTPCodecType MimeType() string TotalTracks() int - OnEnded(func()) - onEnded() + Context() context.Context Relay(func(webrtc.SSRC, rtp.Packet)) } type Track struct { + context context.Context + cancel context.CancelFunc mu sync.Mutex base baseTrack remoteTrack *remoteTrack @@ -72,6 +73,8 @@ type Track struct { func newTrack(ctx context.Context, clientID string, trackRemote IRemoteTrack, pliInterval time.Duration, onPLI func() error, stats stats.Getter, onStatsUpdated func(*stats.Stats)) ITrack { ctList := newClientTrackList() + localCtx, cancel := context.WithCancel(ctx) + baseTrack := baseTrack{ id: trackRemote.ID(), isScreen: &atomic.Bool{}, @@ -85,6 +88,8 @@ func newTrack(ctx context.Context, clientID string, trackRemote IRemoteTrack, pl t := &Track{ mu: sync.Mutex{}, + context: localCtx, + cancel: cancel, base: baseTrack, onReadCallbacks: make([]func(rtp.Packet, QualityLevel), 0), onEndedCallbacks: make([]func(), 0), @@ -100,11 +105,7 @@ func newTrack(ctx context.Context, clientID string, trackRemote IRemoteTrack, pl t.onRead(p, QualityHigh) } - t.remoteTrack = newRemoteTrack(ctx, trackRemote, pliInterval, onPLI, stats, onStatsUpdated, onRead) - - t.remoteTrack.OnEnded(func() { - t.onEnded() - }) + t.remoteTrack = newRemoteTrack(localCtx, trackRemote, pliInterval, onPLI, stats, onStatsUpdated, onRead) return t } @@ -113,6 +114,10 @@ func (t *Track) ClientID() string { return t.base.clientid } +func (t *Track) Context() context.Context { + return t.context +} + func (t *Track) createLocalTrack() *webrtc.TrackLocalStaticRTP { track, newTrackErr := webrtc.NewTrackLocalStaticRTP(t.remoteTrack.track.Codec().RTPCodecCapability, t.base.id, t.base.streamid) if newTrackErr != nil { @@ -187,8 +192,12 @@ func (t *Track) TotalTracks() int { func (t *Track) subscribe(c *Client) iClientTrack { var ct iClientTrack + ctx, cancel := context.WithCancel(t.Context()) + if t.MimeType() == webrtc.MimeTypeVP9 { ct = &scaleableClientTrack{ + context: ctx, + cancel: cancel, mu: sync.RWMutex{}, id: t.base.id, kind: t.base.kind, @@ -207,15 +216,16 @@ func (t *Track) subscribe(c *Client) iClientTrack { isScreen.Store(t.IsScreen()) ct = &clientTrack{ - id: t.base.id, - mu: sync.RWMutex{}, - client: c, - kind: t.base.kind, - mimeType: t.remoteTrack.track.Codec().MimeType, - localTrack: t.createLocalTrack(), - remoteTrack: t.remoteTrack, - isScreen: isScreen, - onTrackEndedCallbacks: make([]func(), 0), + id: t.base.id, + context: ctx, + cancel: cancel, + mu: sync.RWMutex{}, + client: c, + kind: t.base.kind, + mimeType: t.remoteTrack.track.Codec().MimeType, + localTrack: t.createLocalTrack(), + remoteTrack: t.remoteTrack, + isScreen: isScreen, } } @@ -228,15 +238,9 @@ func (t *Track) subscribe(c *Client) iClientTrack { }) } - t.remoteTrack.OnEnded(func() { - ct.onTrackEnded() - }) - go func() { - clientCtx, cancel := context.WithCancel(c.context) defer cancel() - <-clientCtx.Done() - ct.onTrackEnded() + <-ctx.Done() }() t.base.clientTracks.Add(ct) @@ -255,22 +259,6 @@ func (t *Track) SetAsProcessed() { t.base.isProcessed = true } -func (t *Track) onEnded() { - t.mu.Lock() - defer t.mu.Unlock() - - for _, callback := range t.onEndedCallbacks { - callback() - } -} - -func (t *Track) OnEnded(callback func()) { - t.mu.Lock() - defer t.mu.Unlock() - - t.onEndedCallbacks = append(t.onEndedCallbacks, callback) -} - func (t *Track) OnRead(callback func(rtp.Packet, QualityLevel)) { t.mu.Lock() defer t.mu.Unlock() @@ -318,7 +306,6 @@ type SimulcastTrack struct { lastMidKeyframeTS *atomic.Int64 lastLowKeyframeTS *atomic.Int64 onAddedRemoteTrackCallbacks []func(*remoteTrack) - onEndedCallbacks []func() onReadCallbacks []func(rtp.Packet, QualityLevel) pliInterval time.Duration onPLI func() error @@ -346,7 +333,6 @@ func newSimulcastTrack(ctx context.Context, clientid string, track IRemoteTrack, lastLowKeyframeTS: &atomic.Int64{}, onTrackCompleteCallbacks: make([]func(), 0), onAddedRemoteTrackCallbacks: make([]func(*remoteTrack), 0), - onEndedCallbacks: make([]func(), 0), onReadCallbacks: make([]func(rtp.Packet, QualityLevel), 0), pliInterval: pliInterval, onPLI: onPLI, @@ -361,6 +347,10 @@ func (t *SimulcastTrack) ClientID() string { return t.base.clientid } +func (t *SimulcastTrack) Context() context.Context { + return t.context +} + func (t *SimulcastTrack) onRemoteTrackAdded(f func(*remoteTrack)) { t.mu.Lock() defer t.mu.Unlock() @@ -473,45 +463,42 @@ func (t *SimulcastTrack) AddRemoteTrack(track IRemoteTrack, stats stats.Getter, t.remoteTrackHigh = remoteTrack t.mu.Unlock() - remoteTrack.OnEnded(func() { + go func() { + ctx, cancel := context.WithCancel(remoteTrack.Context()) + defer cancel() + <-ctx.Done() t.mu.Lock() t.remoteTrackHigh = nil t.mu.Unlock() - - if t.remoteTrackHigh == nil && t.remoteTrackMid == nil && t.remoteTrackLow == nil { - t.onEnded() - } - }) + }() case QualityMid: t.mu.Lock() t.remoteTrackMid = remoteTrack t.mu.Unlock() - remoteTrack.OnEnded(func() { + go func() { + ctx, cancel := context.WithCancel(remoteTrack.Context()) + defer cancel() + <-ctx.Done() t.mu.Lock() t.remoteTrackMid = nil t.mu.Unlock() - - if t.remoteTrackHigh == nil && t.remoteTrackMid == nil && t.remoteTrackLow == nil { - t.onEnded() - } - }) + }() case QualityLow: t.mu.Lock() t.remoteTrackLow = remoteTrack t.mu.Unlock() - remoteTrack.OnEnded(func() { + go func() { + ctx, cancel := context.WithCancel(remoteTrack.Context()) + defer cancel() + <-ctx.Done() t.mu.Lock() t.remoteTrackLow = nil t.mu.Unlock() - - if t.remoteTrackHigh == nil && t.remoteTrackMid == nil && t.remoteTrackLow == nil { - t.onEnded() - } - }) + }() default: glog.Warning("client: unknown track quality ", track.RID()) return nil @@ -559,9 +546,13 @@ func (t *SimulcastTrack) subscribe(client *Client) iClientTrack { lastTimestamp := &atomic.Uint32{} + ctx, cancel := context.WithCancel(t.Context()) + ct := &simulcastClientTrack{ mu: sync.RWMutex{}, id: t.base.id, + context: ctx, + cancel: cancel, kind: t.base.kind, mimeType: t.base.codec.MimeType, client: client, @@ -581,30 +572,11 @@ func (t *SimulcastTrack) subscribe(client *Client) iClientTrack { ct.SetMaxQuality(QualityHigh) - if t.remoteTrackLow != nil { - t.remoteTrackLow.OnEnded(func() { - ct.onTrackEnded() - }) - } - go func() { - clientCtx, cancel := context.WithCancel(client.context) defer cancel() - <-clientCtx.Done() - ct.onTrackEnded() + <-ctx.Done() }() - if t.remoteTrackMid != nil { - t.remoteTrackMid.OnEnded(func() { - ct.onTrackEnded() - }) - } - if t.remoteTrackHigh != nil { - t.remoteTrackHigh.OnEnded(func() { - ct.onTrackEnded() - }) - } - t.base.clientTracks.Add(ct) return ct @@ -730,22 +702,6 @@ func (t *SimulcastTrack) MimeType() string { return t.base.codec.MimeType } -func (t *SimulcastTrack) OnEnded(f func()) { - t.mu.Lock() - defer t.mu.Unlock() - - t.onEndedCallbacks = append(t.onEndedCallbacks, f) -} - -func (t *SimulcastTrack) onEnded() { - t.mu.Lock() - defer t.mu.Unlock() - - for _, callback := range t.onEndedCallbacks { - callback() - } -} - func (t *SimulcastTrack) OnRead(callback func(rtp.Packet, QualityLevel)) { t.mu.Lock() defer t.mu.Unlock() @@ -850,9 +806,12 @@ func (t *trackList) Add(track ITrack) error { t.tracks[id] = track - track.OnEnded(func() { + go func() { + ctx, cancel := context.WithCancel(track.Context()) + defer cancel() + <-ctx.Done() t.remove([]string{id}) - }) + }() return nil } From 869ae1ce1c71ed4dd28e2b1a05a58c411e9c352f Mon Sep 17 00:00:00 2001 From: Yohan Totting Date: Wed, 22 Nov 2023 15:54:59 +0700 Subject: [PATCH 08/13] use onTargetBitrateChanged callback --- bitratecontroller.go | 54 ++++++++++++++++++++++++++++++++++---------- client.go | 14 +++++++----- 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/bitratecontroller.go b/bitratecontroller.go index 3099904..431a482 100644 --- a/bitratecontroller.go +++ b/bitratecontroller.go @@ -8,6 +8,7 @@ import ( "time" "github.com/golang/glog" + "github.com/pion/interceptor/pkg/cc" "github.com/pion/webrtc/v3" ) @@ -69,19 +70,22 @@ func (c *bitrateClaim) pushbackDelayCounter() { } type bitrateController struct { - mu sync.RWMutex - client *Client - claims map[string]*bitrateClaim + mu sync.RWMutex + lastBitrateAdjustmentTS time.Time + client *Client + claims map[string]*bitrateClaim } -func newbitrateController(client *Client, intervalMonitor time.Duration) *bitrateController { +func newbitrateController(client *Client, intervalMonitor time.Duration, useBandwidthEstimation bool) *bitrateController { bc := &bitrateController{ mu: sync.RWMutex{}, client: client, claims: make(map[string]*bitrateClaim, 0), } - bc.start() + if !useBandwidthEstimation { + bc.start() + } return bc } @@ -244,6 +248,7 @@ func (bc *bitrateController) addAudioClaims(clientTracks []iClientTrack) (leftTr return leftTracks, nil } +// this should never return QualityNone becaus it will delay onTrack event func (bc *bitrateController) getDistributedQuality(totalTracks int) QualityLevel { if totalTracks == 0 { return 0 @@ -255,9 +260,7 @@ func (bc *bitrateController) getDistributedQuality(totalTracks int) QualityLevel bitrateConfig := bc.client.SFU().bitratesConfig - if distributedBandwidth < bitrateConfig.VideoLow { - return QualityNone - } else if distributedBandwidth < bitrateConfig.VideoMid { + if distributedBandwidth < bitrateConfig.VideoMid { return QualityLow } else if distributedBandwidth < bitrateConfig.VideoHigh { return QualityMid @@ -281,10 +284,6 @@ func (bc *bitrateController) addClaims(clientTracks []iClientTrack) error { for _, clientTrack := range leftTracks { if clientTrack.Kind() == webrtc.RTPCodecTypeVideo { trackQuality := bc.getDistributedQuality(len(leftTracks) - claimed) - if clientTrack.IsScreen() && trackQuality == QualityNone { - trackQuality = QualityLow - } - if _, ok := bc.claims[clientTrack.ID()]; ok { errors = append(errors, ErrAlreadyClaimed) continue @@ -447,6 +446,37 @@ func (bc *bitrateController) start() { }() } +func (bc *bitrateController) needIncreaseBitrate(availableBw uint32) bool { + claims := bc.Claims() + + for _, claim := range claims { + if claim.quality < QualityHigh { + if bc.isEnoughBandwidthToIncrase(availableBw, claim) { + return true + } + } + } + + return false +} + +func (bc *bitrateController) MonitorBandwidth(estimator cc.BandwidthEstimator) { + estimator.OnTargetBitrateChange(func(bw int) { + + availableBw := uint32(bw) - bc.totalSentBitrates() + + if bc.totalSentBitrates() > uint32(bw) || (bc.totalSentBitrates() < uint32(bw) && bc.needIncreaseBitrate(availableBw)) { + if time.Since(bc.lastBitrateAdjustmentTS) > 500*time.Millisecond { + glog.Info("brcontroller: target bitrate changed to ", ThousandSeparator(bw), ", will check and adjust bitrates") + + bc.checkAndAdjustBitrates() + + bc.lastBitrateAdjustmentTS = time.Now() + } + } + }) +} + // checkAndAdjustBitrates will check if the available bandwidth is enough to send the current bitrate // if not then it will try to reduce one by one of simulcast track quality until it fit the bandwidth // if the bandwidth is enough to send the current bitrate, then it will try to increase the bitrate diff --git a/client.go b/client.go index d522378..492fb7f 100644 --- a/client.go +++ b/client.go @@ -304,7 +304,7 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi client.stats = newClientStats(client) - client.bitrateController = newbitrateController(client, s.pliInterval) + client.bitrateController = newbitrateController(client, s.pliInterval, s.enableBandwidthEstimator) if s.enableBandwidthEstimator { go func() { @@ -313,6 +313,8 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi defer client.mu.Unlock() client.estimator = estimator + + client.bitrateController.MonitorBandwidth(estimator) }() } @@ -391,12 +393,12 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi simulcast.AddRemoteTrack(remoteTrack, client.statsGetter, onStatsUpdated) } - // // only process track when the lowest quality is available - // simulcast.mu.Lock() - // isLowAvailable := simulcast.remoteTrackLow != nil - // simulcast.mu.Unlock() + // only process track when the lowest quality is available + simulcast.mu.Lock() + isLowAvailable := simulcast.remoteTrackLow != nil + simulcast.mu.Unlock() - if !track.IsProcessed() { + if !track.IsProcessed() && isLowAvailable { client.onTrack(track) track.SetAsProcessed() } From a046021a24ba58141bdd994620738971bda9f9ff Mon Sep 17 00:00:00 2001 From: Yohan Totting Date: Wed, 22 Nov 2023 16:00:03 +0700 Subject: [PATCH 09/13] fix panic nil --- client.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/client.go b/client.go index 492fb7f..7b61b5a 100644 --- a/client.go +++ b/client.go @@ -389,6 +389,11 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi <-ctx.Done() client.stats.removeReceiverStats(remoteTrack.ID()) }() + + if simulcast, ok = track.(*SimulcastTrack); !ok { + glog.Error("client: error track is not simulcast track") + } + } else if simulcast, ok = track.(*SimulcastTrack); ok { simulcast.AddRemoteTrack(remoteTrack, client.statsGetter, onStatsUpdated) } From 92e0a4ae5aa686bc1db6328ed7ffc1e07e76699a Mon Sep 17 00:00:00 2001 From: Yohan Totting Date: Wed, 22 Nov 2023 16:08:24 +0700 Subject: [PATCH 10/13] adjust test to follow latest config --- main_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main_test.go b/main_test.go index 657d5e8..ec55550 100644 --- a/main_test.go +++ b/main_test.go @@ -23,8 +23,9 @@ func TestMain(m *testing.M) { // create room manager first before create new room roomManager = NewManager(ctx, "test", Options{ - WebRTCPort: 40004, ConnectRemoteRoomTimeout: 30 * time.Second, + EnableMux: false, + EnableBandwidthEstimator: true, IceServers: DefaultTestIceServers(), }) From e628b2f75b43283252ffec44f41607ec83d4fca5 Mon Sep 17 00:00:00 2001 From: Yohan Totting Date: Wed, 22 Nov 2023 16:11:23 +0700 Subject: [PATCH 11/13] disable datachannel parallel test --- datachannel_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datachannel_test.go b/datachannel_test.go index c467a91..cdf5b88 100644 --- a/datachannel_test.go +++ b/datachannel_test.go @@ -11,7 +11,7 @@ import ( ) func TestRoomDataChannel(t *testing.T) { - t.Parallel() + // t.Parallel() roomID := roomManager.CreateRoomID() roomName := "test-room" @@ -104,7 +104,7 @@ Loop: } func TestRoomDataChannelWithClientID(t *testing.T) { - t.Parallel() + // t.Parallel() roomID := roomManager.CreateRoomID() roomName := "test-room" From b6fa0870b5344af8802edb2a028c8a2dbd2cbb9d Mon Sep 17 00:00:00 2001 From: Yohan Totting Date: Wed, 22 Nov 2023 16:19:03 +0700 Subject: [PATCH 12/13] update datachannel test --- datachannel_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/datachannel_test.go b/datachannel_test.go index cdf5b88..1ebb3a6 100644 --- a/datachannel_test.go +++ b/datachannel_test.go @@ -11,7 +11,7 @@ import ( ) func TestRoomDataChannel(t *testing.T) { - // t.Parallel() + t.Parallel() roomID := roomManager.CreateRoomID() roomName := "test-room" @@ -23,6 +23,9 @@ func TestRoomDataChannel(t *testing.T) { require.NoError(t, err, "error creating room: %v", err) ctx := testRoom.sfu.context + err = testRoom.CreateDataChannel("chat", DefaultDataChannelOptions()) + require.NoError(t, err) + pc1, client1, _ := CreateDataPair(ctx, testRoom, roomManager.options.IceServers, "peer1") pc2, client2, _ := CreateDataPair(ctx, testRoom, roomManager.options.IceServers, "peer2") @@ -74,9 +77,6 @@ func TestRoomDataChannel(t *testing.T) { require.True(t, isConnected) - err = testRoom.CreateDataChannel("chat", DefaultDataChannelOptions()) - require.NoError(t, err) - // make sure to return error on creating data channel with same label err = testRoom.CreateDataChannel("chat", DefaultDataChannelOptions()) require.Error(t, err) @@ -104,7 +104,7 @@ Loop: } func TestRoomDataChannelWithClientID(t *testing.T) { - // t.Parallel() + t.Parallel() roomID := roomManager.CreateRoomID() roomName := "test-room" From 3e81d98e1d9fe1f2b156efcc20aa66c27bf94f09 Mon Sep 17 00:00:00 2001 From: Yohan Totting Date: Wed, 22 Nov 2023 16:26:54 +0700 Subject: [PATCH 13/13] adjust wait connected --- datachannel_test.go | 2 +- testhelper.go | 37 ++++++++++++++++++++++++------------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/datachannel_test.go b/datachannel_test.go index 1ebb3a6..d74f76a 100644 --- a/datachannel_test.go +++ b/datachannel_test.go @@ -63,7 +63,7 @@ func TestRoomDataChannel(t *testing.T) { connected := WaitConnected(ctx, []*webrtc.PeerConnection{pc1, pc2}) - timeoutConnected, cancelTimeoutConnected := context.WithTimeout(ctx, 30*time.Second) + timeoutConnected, cancelTimeoutConnected := context.WithTimeout(ctx, 40*time.Second) isConnected := false select { diff --git a/testhelper.go b/testhelper.go index 88178c5..71c6723 100644 --- a/testhelper.go +++ b/testhelper.go @@ -421,7 +421,7 @@ func CreatePeerPair(ctx context.Context, room *Room, iceServers []webrtc.ICEServ pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { if state == webrtc.PeerConnectionStateClosed || state == webrtc.PeerConnectionStateFailed { - glog.Info("test: peer connection closed ", peerName) + glog.Info("test: peer connection ", peerName, " stated changed ", state) if client != nil { _ = room.StopClient(client.ID()) cancelClient() @@ -576,7 +576,13 @@ func CreateDataPair(ctx context.Context, room *Room, iceServers []webrtc.ICEServ ICEServers: iceServers, }) - pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RtpTransceiverInit{Direction: webrtc.RTPTransceiverDirectionRecvonly}) + if _, err := pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RtpTransceiverInit{Direction: webrtc.RTPTransceiverDirectionRecvonly}); err != nil { + panic(err) + } + + if _, err := pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RtpTransceiverInit{Direction: webrtc.RTPTransceiverDirectionRecvonly}); err != nil { + panic(err) + } pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { if state == webrtc.PeerConnectionStateClosed || state == webrtc.PeerConnectionStateFailed { @@ -638,17 +644,6 @@ func WaitConnected(ctx context.Context, peers []*webrtc.PeerConnection) chan boo connectedCount := 0 ctxx, cancel := context.WithCancel(ctx) - for _, pc := range peers { - pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { - if state == webrtc.PeerConnectionStateConnected { - connected <- true - } else if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed { - waitChan <- false - cancel() - } - }) - } - go func() { defer cancel() @@ -667,5 +662,21 @@ func WaitConnected(ctx context.Context, peers []*webrtc.PeerConnection) chan boo } }() + for _, pc := range peers { + if pc.ConnectionState() == webrtc.PeerConnectionStateConnected { + connected <- true + } else { + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + if state == webrtc.PeerConnectionStateConnected { + connected <- true + } else if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed { + waitChan <- false + cancel() + } + }) + } + + } + return waitChan }