diff --git a/client.go b/client.go
index 093a6c2..5b5460e 100644
--- a/client.go
+++ b/client.go
@@ -169,6 +169,7 @@ type Client struct {
 	onRenegotiation               func(context.Context, webrtc.SessionDescription) (webrtc.SessionDescription, error)
 	onAllowedRemoteRenegotiation  func()
 	onTracksAvailableCallbacks    []func([]ITrack)
+	onTracksReadyCallbacks        []func([]ITrack)
 	onNetworkConditionChangedFunc func(networkmonitor.NetworkConditionType)
 	// onTrack is used by SFU to take action when a new track is added to the client
 	onTrack func(ITrack)
@@ -597,7 +598,7 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi
 	})
 
 	peerConnection.OnNegotiationNeeded(func() {
-		client.renegotiate()
+		client.renegotiate(false)
 	})
 
 	return client
@@ -684,7 +685,7 @@ func (c *Client) Negotiate(offer webrtc.SessionDescription) (*webrtc.SessionDesc
 	defer func() {
 		c.isInRemoteNegotiation.Store(false)
 		if c.negotiationNeeded.Load() {
-			c.renegotiate()
+			c.renegotiate(false)
 		}
 	}()
 
@@ -844,7 +845,7 @@ func (c *Client) OnRenegotiation(callback func(context.Context, webrtc.SessionDe
 	c.onRenegotiation = callback
 }
 
-func (c *Client) renegotiate() {
+func (c *Client) renegotiate(offerFlexFec bool) {
 	c.log.Debug("client: renegotiate")
 	c.negotiationNeeded.Store(true)
 
@@ -901,6 +902,12 @@ func (c *Client) renegotiate() {
 		return
 	}
 
+	if offerFlexFec {
+		// munge the offer to include FlexFEC
+		// get the payload type of the video track
+
+	}
+
 	// Sets the LocalDescription, and starts our UDP listeners
 	err = c.peerConnection.PC().SetLocalDescription(offer)
 	if err != nil {
@@ -1366,6 +1373,7 @@ func (c *Client) SetTracksSourceType(trackTypes map[string]TrackType) {
 		// broadcast to other clients available tracks from this client
 		c.log.Debugf("client: %s set source tracks %d", c.ID(), len(availableTracks))
 		c.sfu.onTracksAvailable(c.ID(), availableTracks)
+		c.onTracksReady(availableTracks)
 	}
 }
 
@@ -1708,6 +1716,7 @@ func (c *Client) SFU() *SFU {
 
 // OnTracksAvailable event is called when the SFU is trying to publish new tracks to the client.
 // The client then can subscribe to the tracks by calling `client.SubscribeTracks()` method.
+// The callback will receive the list of tracks from other clients that are available to subscribe to.
 func (c *Client) OnTracksAvailable(callback func([]ITrack)) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
@@ -1720,6 +1729,20 @@ func (c *Client) onTracksAvailable(tracks []ITrack) {
 		callback(tracks)
 	}
 }
+
+// OnTracksReady event is called when the client's published tracks are ready to use.
+// This can be used to hook processing such as transcription or video processing into the client's published tracks.
+func (c *Client) OnTracksReady(callback func([]ITrack)) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.onTracksReadyCallbacks = append(c.onTracksReadyCallbacks, callback)
+}
+
+func (c *Client) onTracksReady(tracks []ITrack) {
+	for _, callback := range c.onTracksReadyCallbacks {
+		callback(tracks)
+	}
+}
+
 // OnVoiceDetected event is called when the SFU is detecting voice activity in the room.
 // The callback will receive the voice activity data that can be use for visual indicator of current speaker.
 func (c *Client) OnVoiceSentDetected(callback func(activity voiceactivedetector.VoiceActivity)) {
diff --git a/packetbuffers.go b/packetbuffers.go
index e8b95d9..3238fdb 100644
--- a/packetbuffers.go
+++ b/packetbuffers.go
@@ -17,15 +17,14 @@ var (
 	ErrPacketDuplicate = errors.New("packetbuffers: packet is duplicate")
 )
 
-type packet struct {
-	packet    *rtppool.RetainablePacket
+type Packet struct {
+	Packet    *rtppool.RetainablePacket
 	addedTime time.Time
 }
 
 // buffer ring for cached packets
-type packetBuffers struct {
+type PacketBuffers struct {
 	context context.Context
-	cancel  context.CancelFunc
 	init    bool
 	mu      sync.RWMutex
 	buffers *list.List
@@ -35,7 +34,7 @@ type packetBuffers struct {
 	minLatency   time.Duration
 	// max duration to wait before sending
 	maxLatency   time.Duration
-	oldestPacket *packet
+	oldestPacket *Packet
 	initSequence uint16
 	packetCount  uint64
 	waitTimeMu   sync.RWMutex
@@ -51,9 +50,9 @@ type packetBuffers struct {
 
 const waitTimeSize = 2500
 
-func newPacketBuffers(ctx context.Context, minLatency, maxLatency time.Duration, dynamicLatency bool, log logging.LeveledLogger) *packetBuffers {
+func NewPacketBuffers(ctx context.Context, minLatency, maxLatency time.Duration, dynamicLatency bool, log logging.LeveledLogger) *PacketBuffers {
 	ctx, cancel := context.WithCancel(ctx)
-	p := &packetBuffers{
+	p := &PacketBuffers{
 		context: ctx,
 		mu:      sync.RWMutex{},
 		buffers: list.New(),
@@ -77,24 +76,24 @@ func newPacketBuffers(ctx context.Context, minLatency, maxLatency time.Duration,
 	return p
 }
 
-func (p *packetBuffers) MaxLatency() time.Duration {
+func (p *PacketBuffers) MaxLatency() time.Duration {
 	p.latencyMu.RLock()
 	defer p.latencyMu.RUnlock()
 	return p.maxLatency
 }
 
-func (p *packetBuffers) MinLatency() time.Duration {
+func (p *PacketBuffers) MinLatency() time.Duration {
 	p.latencyMu.RLock()
 	defer p.latencyMu.RUnlock()
 	return p.minLatency
 }
 
-func (p *packetBuffers) Initiated() bool {
+func (p *PacketBuffers) Initiated() bool {
 	p.mu.RLock()
 	defer p.mu.RUnlock()
 	return p.init
}
 
-func (p *packetBuffers) Add(pkt *rtppool.RetainablePacket) error {
+func (p *PacketBuffers) Add(pkt *rtppool.RetainablePacket) error {
 	p.mu.Lock()
 	defer func() {
 		p.checkOrderedPacketAndRecordTimes()
@@ -110,8 +109,8 @@ func (p *packetBuffers) Add(pkt *rtppool.RetainablePacket) error {
 		p.log.Warnf("packet cache: packet sequence ", pkt.Header().SequenceNumber, " is too late, last sent was ", p.lastSequenceNumber, ", will not adding the packet")
 
 		// add to to front of the list to make sure pop take it first
-		p.buffers.PushFront(&packet{
-			packet:    pkt,
+		p.buffers.PushFront(&Packet{
+			Packet:    pkt,
 			addedTime: time.Now(),
 		})
 
@@ -119,8 +118,8 @@ func (p *packetBuffers) Add(pkt *rtppool.RetainablePacket) error {
 	}
 
 	if p.buffers.Len() == 0 {
-		p.buffers.PushBack(&packet{
-			packet:    pkt,
+		p.buffers.PushBack(&Packet{
+			Packet:    pkt,
 			addedTime: time.Now(),
 		})
 
@@ -129,84 +128,84 @@ func (p *packetBuffers) Add(pkt *rtppool.RetainablePacket) error {
 	// add packet in order
 Loop:
 	for e := p.buffers.Back(); e != nil; e = e.Prev() {
-		currentPkt := e.Value.(*packet)
+		currentPkt := e.Value.(*Packet)
 
-		if err := currentPkt.packet.Retain(); err != nil {
+		if err := currentPkt.Packet.Retain(); err != nil {
 			// already released
 			continue
 		}
 
-		if currentPkt.packet.Header().SequenceNumber == pkt.Header().SequenceNumber {
+		if currentPkt.Packet.Header().SequenceNumber == pkt.Header().SequenceNumber {
 			// p.log.Warnf("packet cache: packet sequence ", pkt.SequenceNumber, " already exists in the cache, will not adding the packet")
-			currentPkt.packet.Release()
+			currentPkt.Packet.Release()
 
 			return ErrPacketDuplicate
 		}
 
-		if currentPkt.packet.Header().SequenceNumber < pkt.Header().SequenceNumber && pkt.Header().SequenceNumber-currentPkt.packet.Header().SequenceNumber < uint16SizeHalf {
-			p.buffers.InsertAfter(&packet{
-				packet:    pkt,
+		if currentPkt.Packet.Header().SequenceNumber < pkt.Header().SequenceNumber && pkt.Header().SequenceNumber-currentPkt.Packet.Header().SequenceNumber < uint16SizeHalf {
+			p.buffers.InsertAfter(&Packet{
+				Packet:    pkt,
 				addedTime: time.Now(),
 			}, e)
 
-			currentPkt.packet.Release()
+			currentPkt.Packet.Release()
 
 			break Loop
 		}
 
-		if currentPkt.packet.Header().SequenceNumber-pkt.Header().SequenceNumber > uint16SizeHalf {
-			p.buffers.InsertAfter(&packet{
-				packet:    pkt,
+		if currentPkt.Packet.Header().SequenceNumber-pkt.Header().SequenceNumber > uint16SizeHalf {
+			p.buffers.InsertAfter(&Packet{
+				Packet:    pkt,
 				addedTime: time.Now(),
 			}, e)
 
-			currentPkt.packet.Release()
+			currentPkt.Packet.Release()
 
 			break Loop
 		}
 
 		if e.Prev() == nil {
-			p.buffers.PushFront(&packet{
-				packet:    pkt,
+			p.buffers.PushFront(&Packet{
+				Packet:    pkt,
 				addedTime: time.Now(),
 			})
 
-			currentPkt.packet.Release()
+			currentPkt.Packet.Release()
 
 			break Loop
 		}
 
-		currentPkt.packet.Release()
+		currentPkt.Packet.Release()
 	}
 
 	return nil
 }
 
-func (p *packetBuffers) pop(el *list.Element) *packet {
-	pkt := el.Value.(*packet)
+func (p *PacketBuffers) pop(el *list.Element) *Packet {
+	pkt := el.Value.(*Packet)
 
-	if err := pkt.packet.Retain(); err != nil {
+	if err := pkt.Packet.Retain(); err != nil {
 		// already released
 		return nil
 	}
 
 	defer func() {
-		pkt.packet.Release()
+		pkt.Packet.Release()
 	}()
 
 	// make sure packet is not late
-	if IsRTPPacketLate(pkt.packet.Header().SequenceNumber, p.lastSequenceNumber) {
-		p.log.Warnf("packet cache: packet sequence ", pkt.packet.Header().SequenceNumber, " is too late, last sent was ", p.lastSequenceNumber)
+	if IsRTPPacketLate(pkt.Packet.Header().SequenceNumber, p.lastSequenceNumber) {
+		p.log.Warnf("packet cache: packet sequence ", pkt.Packet.Header().SequenceNumber, " is too late, last sent was ", p.lastSequenceNumber)
 	}
 
-	if p.init && pkt.packet.Header().SequenceNumber > p.lastSequenceNumber && pkt.packet.Header().SequenceNumber-p.lastSequenceNumber > 1 {
+	if p.init && pkt.Packet.Header().SequenceNumber > p.lastSequenceNumber && pkt.Packet.Header().SequenceNumber-p.lastSequenceNumber > 1 {
 		// make sure packet has no gap
-		p.log.Warnf("packet cache: packet sequence ", pkt.packet.Header().SequenceNumber, " has a gap with last sent ", p.lastSequenceNumber)
+		p.log.Warnf("packet cache: packet sequence ", pkt.Packet.Header().SequenceNumber, " has a gap with last sent ", p.lastSequenceNumber)
 	}
 
 	p.mu.Lock()
-	p.lastSequenceNumber = pkt.packet.Header().SequenceNumber
+	p.lastSequenceNumber = pkt.Packet.Header().SequenceNumber
 	p.packetCount++
 	p.mu.Unlock()
 
-	if p.oldestPacket != nil && p.oldestPacket.packet.Header().SequenceNumber == pkt.packet.Header().SequenceNumber {
+	if p.oldestPacket != nil && p.oldestPacket.Packet.Header().SequenceNumber == pkt.Packet.Header().SequenceNumber {
 		// oldest packet will be remove, find the next oldest packet in the list
 		p.mu.RLock()
 		for e := el.Next(); e != nil; e = e.Next() {
-			packet := e.Value.(*packet)
+			packet := e.Value.(*Packet)
 			if packet.addedTime.After(p.oldestPacket.addedTime) {
 				p.oldestPacket = packet
 			}
@@ -223,8 +222,8 @@ func (p *packetBuffers) pop(el *list.Element) *packet {
 
 	return pkt
 }
 
-func (p *packetBuffers) flush() []*packet {
-	packets := make([]*packet, 0)
+func (p *PacketBuffers) flush() []*Packet {
+	packets := make([]*Packet, 0)
 
 	if p.oldestPacket != nil && time.Since(p.oldestPacket.addedTime) > p.maxLatency {
 		// we have waited too long, we should send the packets
@@ -244,10 +243,10 @@ Loop:
 	return packets
 }
 
-func (p *packetBuffers) sendOldestPacket() *packet {
+func (p *PacketBuffers) sendOldestPacket() *Packet {
 	for e := p.buffers.Front(); e != nil; e = e.Next() {
-		packet := e.Value.(*packet)
-		if packet.packet.Header().SequenceNumber == p.oldestPacket.packet.Header().SequenceNumber {
+		packet := e.Value.(*Packet)
+		if packet.Packet.Header().SequenceNumber == p.oldestPacket.Packet.Header().SequenceNumber {
 			return p.pop(e)
 		}
 	}
@@ -255,19 +254,19 @@ func (p *packetBuffers) sendOldestPacket() *packet {
 	return nil
 }
 
-func (p *packetBuffers) fetch(e *list.Element) *packet {
+func (p *PacketBuffers) fetch(e *list.Element) *Packet {
 	p.latencyMu.RLock()
 	maxLatency := p.maxLatency
 	minLatency := p.minLatency
 	p.latencyMu.RUnlock()
 
-	currentPacket := e.Value.(*packet)
+	currentPacket := e.Value.(*Packet)
 
-	currentSeq := currentPacket.packet.Header().SequenceNumber
+	currentSeq := currentPacket.Packet.Header().SequenceNumber
 
 	latency := time.Since(currentPacket.addedTime)
 
-	if !p.Initiated() && latency > minLatency && e.Next() != nil && !IsRTPPacketLate(e.Next().Value.(*rtppool.RetainablePacket).Header().SequenceNumber, currentSeq) {
+	if !p.Initiated() && latency > minLatency && e.Next() != nil && !IsRTPPacketLate(e.Next().Value.(*Packet).Packet.Header().SequenceNumber, currentSeq) {
 		// first packet to send, but make sure we have the packet in order
 		p.mu.Lock()
 		p.initSequence = currentSeq
@@ -294,14 +293,14 @@ func (p *packetBuffers) fetch(e *list.Element) *packet {
 	// p.log.Infof("packet latency: ", packetLatency, " gap: ", gap, " currentSeq: ", currentSeq, " nextSeq: ", nextSeq)
 	if latency > maxLatency {
 		// we have waited too long, we should send the packets
-		p.log.Warnf("packet cache: packet sequence ", currentPacket.packet.Header().SequenceNumber, " latency ", latency, ", reached max latency ", maxLatency, ", will sending the packets")
+		p.log.Warnf("packet cache: packet sequence ", currentPacket.Packet.Header().SequenceNumber, " latency ", latency, ", reached max latency ", maxLatency, ", will sending the packets")
 		return p.pop(e)
 	}
 
 	return nil
 }
 
-func (p *packetBuffers) Pop() *packet {
+func (p *PacketBuffers) Pop() *Packet {
 	p.mu.RLock()
 	if p.oldestPacket != nil && time.Since(p.oldestPacket.addedTime) > p.maxLatency {
 		p.mu.RUnlock()
@@ -316,16 +315,16 @@ func (p *packetBuffers) Pop() *packet {
 		return nil
 	}
 
-	item := frontElement.Value.(*rtppool.RetainablePacket)
-	if err := item.Retain(); err != nil {
+	item := frontElement.Value.(*Packet)
+	if err := item.Packet.Retain(); err != nil {
 		return nil
 	}
 
 	defer func() {
-		item.Release()
+		item.Packet.Release()
 	}()
 
-	if IsRTPPacketLate(item.Header().SequenceNumber, p.lastSequenceNumber) {
+	if IsRTPPacketLate(item.Packet.Header().SequenceNumber, p.lastSequenceNumber) {
 		return p.pop(frontElement)
 	}
 
@@ -333,11 +332,11 @@ func (p *packetBuffers) Pop() *packet {
 	return p.fetch(frontElement)
 }
 
-func (p *packetBuffers) Flush() []*packet {
+func (p *PacketBuffers) Flush() []*Packet {
 	return p.flush()
 }
 
-func (p *packetBuffers) Last() *packet {
+func (p *PacketBuffers) Last() *Packet {
 	p.mu.RLock()
 	defer p.mu.RUnlock()
 
@@ -345,23 +344,23 @@ func (p *packetBuffers) Last() *packet {
 		return nil
 	}
 
-	return p.buffers.Back().Value.(*packet)
+	return p.buffers.Back().Value.(*Packet)
 }
 
-func (p *packetBuffers) Len() int {
+func (p *PacketBuffers) Len() int {
 	p.mu.RLock()
 	defer p.mu.RUnlock()
 
 	return p.buffers.Len()
 }
 
-func (p *packetBuffers) Clear() {
+func (p *PacketBuffers) Clear() {
 	p.mu.Lock()
 	defer p.mu.Unlock()
 
 	for e := p.buffers.Front(); e != nil; e = e.Next() {
-		packet := e.Value.(*rtppool.RetainablePacket)
-		packet.Release()
+		packet := e.Value.(*Packet)
+		packet.Packet.Release()
 		p.buffers.Remove(e)
 	}
 
@@ -370,7 +369,7 @@ func (p *packetBuffers) Clear() {
 	p.oldestPacket = nil
 }
 
-func (p *packetBuffers) Close() {
+func (p *PacketBuffers) Close() {
 	p.mu.Lock()
 	p.ended = true
 	p.mu.Unlock()
@@ -380,7 +379,7 @@
 	// make sure we don't have any waiters
 	p.packetAvailableWait.Signal()
 }
 
-func (p *packetBuffers) WaitAvailablePacket() {
+func (p *PacketBuffers) WaitAvailablePacket() {
 	p.mu.RLock()
 	if p.buffers.Len() == 0 {
 		p.mu.RUnlock()
@@ -399,20 +398,20 @@
 	p.packetAvailableWait.Wait()
 }
 
-func (p *packetBuffers) checkOrderedPacketAndRecordTimes() {
+func (p *PacketBuffers) checkOrderedPacketAndRecordTimes() {
 	for e := p.buffers.Front(); e != nil; e = e.Next() {
-		pkt := e.Value.(*packet)
+		pkt := e.Value.(*Packet)
 
 		// make sure call retain to prevent packet from being released when we are still using it
-		if err := pkt.packet.Retain(); err != nil {
+		if err := pkt.Packet.Retain(); err != nil {
 			// already released
 			continue
 		}
 
-		currentSeq := pkt.packet.Header().SequenceNumber
+		currentSeq := pkt.Packet.Header().SequenceNumber
 		latency := time.Since(pkt.addedTime)
 
-		if !p.init && latency > p.minLatency && e.Next() != nil && !IsRTPPacketLate(e.Next().Value.(*rtppool.RetainablePacket).Header().SequenceNumber, currentSeq) {
+		if !p.init && latency > p.minLatency && e.Next() != nil && !IsRTPPacketLate(e.Next().Value.(*Packet).Packet.Header().SequenceNumber, currentSeq) {
 			// signal first packet to send
 			p.packetAvailableWait.Signal()
 		} else if (p.lastSequenceNumber < currentSeq || p.lastSequenceNumber-currentSeq > uint16SizeHalf) && currentSeq-p.lastSequenceNumber == 1 {
@@ -429,17 +428,17 @@ func (p *packetBuffers) checkOrderedPacketAndRecordTimes() {
 		}
 
 		// release the packet after we are done with it
-		pkt.packet.Release()
+		pkt.Packet.Release()
 	}
 }
 
-func (p *packetBuffers) recordWaitTime(el *list.Element) {
+func (p *PacketBuffers) recordWaitTime(el *list.Element) {
 	p.waitTimeMu.Lock()
 	defer p.waitTimeMu.Unlock()
 
-	pkt := el.Value.(*packet)
+	pkt := el.Value.(*Packet)
 
-	if p.lastSequenceWaitTime == pkt.packet.Header().SequenceNumber || IsRTPPacketLate(pkt.packet.Header().SequenceNumber, p.lastSequenceWaitTime) || p.previousAddedTime.IsZero() {
+	if p.lastSequenceWaitTime == pkt.Packet.Header().SequenceNumber || IsRTPPacketLate(pkt.Packet.Header().SequenceNumber, p.lastSequenceWaitTime) || p.previousAddedTime.IsZero() {
 		// don't record late packet or already recorded packet
 		if p.previousAddedTime.IsZero() {
 			p.previousAddedTime = pkt.addedTime
@@ -448,7 +447,7 @@ func (p *packetBuffers) recordWaitTime(el *list.Element) {
 		return
 	}
 
-	p.lastSequenceWaitTime = pkt.packet.Header().SequenceNumber
+	p.lastSequenceWaitTime = pkt.Packet.Header().SequenceNumber
 
 	// remove oldest packet from the wait times if more than 500
 	if uint16(len(p.waitTimes)+1) > waitTimeSize {
@@ -462,7 +461,7 @@
 
 }
 
-func (p *packetBuffers) checkWaitTimeAdjuster() {
+func (p *PacketBuffers) checkWaitTimeAdjuster() {
 	if p.waitTimeResetCounter > waitTimeSize {
 		p.waitTimeMu.Lock()
 		defer p.waitTimeMu.Unlock()
diff --git a/packetbuffers_test.go b/packetbuffers_test.go
index 90c5fe8..f470b94 100644
--- a/packetbuffers_test.go
+++ b/packetbuffers_test.go
@@ -37,7 +37,7 @@ func TestAdd(t *testing.T) {
 	minLatency := 10 * time.Millisecond
 	maxLatency := 100 * time.Millisecond
 
-	caches := newPacketBuffers(ctx, minLatency, maxLatency, false, logging.NewDefaultLoggerFactory().NewLogger("sfu"))
+	caches := NewPacketBuffers(ctx, minLatency, maxLatency, false, logging.NewDefaultLoggerFactory().NewLogger("sfu"))
 
 	for i, pkt := range packets {
 		rp := pool.NewPacket(&pkt.Header, pkt.Payload)
@@ -77,7 +77,7 @@ func TestAddLost(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	caches := newPacketBuffers(ctx, minLatency, maxLatency, false, logging.NewDefaultLoggerFactory().NewLogger("sfu"))
+	caches := NewPacketBuffers(ctx, minLatency, maxLatency, false, logging.NewDefaultLoggerFactory().NewLogger("sfu"))
 
 	for i, pkt := range packets {
 		if pkt.SequenceNumber == 65533 {
@@ -126,7 +126,7 @@ func TestDuplicateAdd(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	caches := newPacketBuffers(ctx, minLatency, maxLatency, false, logging.NewDefaultLoggerFactory().NewLogger("sfu"))
+	caches := NewPacketBuffers(ctx, minLatency, maxLatency, false, logging.NewDefaultLoggerFactory().NewLogger("sfu"))
 
 	for i, pkt := range packets {
 		if i == 9 {
@@ -152,7 +152,7 @@ func TestDuplicateAdd(t *testing.T) {
 
 	// i := 0
 	// for e := caches.buffers.Front(); e != nil; e = e.Next() {
-	// 	packet := e.Value.(*packet)
+	// 	packet := e.Value.(*Packet)
 	// 	require.Equal(t, packet.RTP.Header().SequenceNumber, sortedNumbers[i], fmt.Sprintf("packet sequence number %d should be equal to sortedNumbers sequence number %d", packet.RTP.Header().SequenceNumber, sortedNumbers[i]))
 	// 	i++
 	// }
@@ -176,7 +176,7 @@ func TestFlush(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	caches := newPacketBuffers(ctx, minLatency, maxLatency, false, logging.NewDefaultLoggerFactory().NewLogger("sfu"))
+	caches := NewPacketBuffers(ctx, minLatency, maxLatency, false, logging.NewDefaultLoggerFactory().NewLogger("sfu"))
 
 	for i, pkt := range packets {
 		rp := pool.NewPacket(&pkt.Header, pkt.Payload)
@@ -218,9 +218,9 @@ func TestFlushBetweenAdded(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	caches := newPacketBuffers(ctx, minLatency, maxLatency, false, logging.NewDefaultLoggerFactory().NewLogger("sfu"))
+	caches := NewPacketBuffers(ctx, minLatency, maxLatency, false, logging.NewDefaultLoggerFactory().NewLogger("sfu"))
 
-	sorted := make([]*packet, 0)
+	sorted := make([]*Packet, 0)
 
 	for i, pkt := range packets {
 		rp := pool.NewPacket(&pkt.Header, pkt.Payload)
@@ -266,9 +266,9 @@ func TestLatency(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	caches := newPacketBuffers(ctx, minLatency, maxLatency, false, logging.NewDefaultLoggerFactory().NewLogger("sfu"))
+	caches := NewPacketBuffers(ctx, minLatency, maxLatency, false, logging.NewDefaultLoggerFactory().NewLogger("sfu"))
 
-	sorted := make([]*packet, 0)
+	sorted := make([]*Packet, 0)
 	seqs := make([]uint16, 0)
 	resultsSeqs := make([]uint16, 0)
 	dropped := 0
@@ -339,7 +339,7 @@ func BenchmarkPushPool(b *testing.B) {
 		}
 	}
 
-	packetBuffers := newPacketBuffers(ctx, 10*time.Millisecond, 100*time.Millisecond, false, logging.NewDefaultLoggerFactory().NewLogger("sfu"))
+	packetBuffers := NewPacketBuffers(ctx, 10*time.Millisecond, 100*time.Millisecond, false, logging.NewDefaultLoggerFactory().NewLogger("sfu"))
 
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
@@ -353,7 +353,7 @@
 func BenchmarkPopPool(b *testing.B) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
-	packetBuffers := newPacketBuffers(ctx, 10*time.Millisecond, 100*time.Millisecond, false, logging.NewDefaultLoggerFactory().NewLogger("sfu"))
+	packetBuffers := NewPacketBuffers(ctx, 10*time.Millisecond, 100*time.Millisecond, false, logging.NewDefaultLoggerFactory().NewLogger("sfu"))
 
 	for i := 0; i < b.N; i++ {
 		rp := pool.NewPacket(&rtp.Header{}, make([]byte, 1400))
diff --git a/remotetrack.go b/remotetrack.go
index 523bf6e..cc84caa 100644
--- a/remotetrack.go
+++ b/remotetrack.go
@@ -11,10 +11,10 @@ import (
 
 	"github.com/inlivedev/sfu/pkg/networkmonitor"
 	"github.com/inlivedev/sfu/pkg/rtppool"
+	"github.com/pion/interceptor"
 	"github.com/pion/interceptor/pkg/stats"
 	"github.com/pion/logging"
 	"github.com/pion/rtp"
-	"github.com/pion/webrtc/v4"
 )
 
 type remoteTrack struct {
@@ -22,7 +22,7 @@ type remoteTrack struct {
 	cancel                context.CancelFunc
 	mu                    sync.RWMutex
 	track                 IRemoteTrack
-	onRead                func(*rtp.Packet)
+	onRead                func(interceptor.Attributes, *rtp.Packet)
 	onPLI                 func()
 	bitrate               *atomic.Uint32
 	previousBytesReceived *atomic.Uint64
@@ -32,13 +32,11 @@ type remoteTrack struct {
 	onEndedCallbacks      []func()
 	statsGetter           stats.Getter
 	onStatsUpdated        func(*stats.Stats)
-	packetBuffers         *packetBuffers
-	looping               bool
 	log                   logging.LeveledLogger
 	rtppool               *rtppool.RTPPool
 }
 
-func newRemoteTrack(ctx context.Context, log logging.LeveledLogger, useBuffer bool, track IRemoteTrack, minWait, maxWait, pliInterval time.Duration, onPLI func(), statsGetter stats.Getter, onStatsUpdated func(*stats.Stats), onRead func(*rtp.Packet), pool *rtppool.RTPPool, onNetworkConditionChanged func(networkmonitor.NetworkConditionType)) *remoteTrack {
+func newRemoteTrack(ctx context.Context, log logging.LeveledLogger, useBuffer bool, track IRemoteTrack, minWait, maxWait, pliInterval time.Duration, onPLI func(), statsGetter stats.Getter, onStatsUpdated func(*stats.Stats), onRead func(interceptor.Attributes, *rtp.Packet), pool *rtppool.RTPPool, onNetworkConditionChanged func(networkmonitor.NetworkConditionType)) *remoteTrack {
 	localctx, cancel := context.WithCancel(ctx)
 
 	rt := &remoteTrack{
@@ -59,27 +57,15 @@ func newRemoteTrack(ctx context.Context, log logging.LeveledLogger, useBuffer bo
 		rtppool:               pool,
 	}
 
-	if useBuffer && track.Kind() == webrtc.RTPCodecTypeVideo {
-		rt.packetBuffers = newPacketBuffers(localctx, minWait, maxWait, true, log)
-	}
-
 	if pliInterval > 0 {
 		rt.enableIntervalPLI(pliInterval)
 	}
 
 	go rt.readRTP()
 
-	if useBuffer && track.Kind() == webrtc.RTPCodecTypeVideo {
-		go rt.loop()
-	}
-
 	return rt
 }
 
-func (t *remoteTrack) Buffered() bool {
-	return t.packetBuffers != nil
-}
-
 func (t *remoteTrack) Context() context.Context {
 	return t.context
 }
@@ -104,7 +90,7 @@ func (t *remoteTrack) readRTP() {
 			}
 
 			buffer := t.rtppool.GetPayload()
-			n, _, readErr := t.track.Read(*buffer)
+			n, attrs, readErr := t.track.Read(*buffer)
 			if readErr != nil {
 				if readErr == io.EOF {
 					t.log.Infof("remotetrack: track ended %s ", t.track.ID())
@@ -136,13 +122,7 @@ func (t *remoteTrack) readRTP() {
 				go t.updateStats()
 			}
 
-			if t.Buffered() && t.Track().Kind() == webrtc.RTPCodecTypeVideo {
-				retainablePacket := t.rtppool.NewPacket(&p.Header, p.Payload)
-				_ = t.packetBuffers.Add(retainablePacket)
-
-			} else {
-				t.onRead(p)
-			}
+			t.onRead(attrs, p)
 
 			t.rtppool.PutPayload(buffer)
 			t.rtppool.PutPacket(p)
@@ -170,71 +150,6 @@ func (t *remoteTrack) unmarshal(buf []byte, p *rtp.Packet) error {
 	return nil
 }
 
-func (t *remoteTrack) loop() {
-	if t.looping {
-		return
-	}
-
-	ctx, cancel := context.WithCancel(t.context)
-	defer cancel()
-
-	t.mu.Lock()
-	t.looping = true
-	t.mu.Unlock()
-
-	defer func() {
-		t.mu.Lock()
-		t.looping = false
-		t.mu.Unlock()
-	}()
-
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		default:
-			t.packetBuffers.WaitAvailablePacket()
-			t.mu.RLock()
-			for orderedPkt := t.packetBuffers.Pop(); orderedPkt != nil; orderedPkt = t.packetBuffers.Pop() {
-				// make sure the we're passing a new packet to the onRead callback
-
-				copyPkt := t.rtppool.GetPacket()
-
-				copyPkt.Header = *orderedPkt.packet.Header()
-
-				copyPkt.Payload = orderedPkt.packet.Payload()
-
-				t.onRead(copyPkt)
-
-				t.rtppool.PutPacket(copyPkt)
-
-				orderedPkt.packet.Release()
-			}
-
-			t.mu.RUnlock()
-
-		}
-	}
-
-}
-
-func (t *remoteTrack) Flush() {
-	pkts := t.packetBuffers.Flush()
-	for _, pkt := range pkts {
-		copyPkt := t.rtppool.GetPacket()
-
-		copyPkt.Header = *pkt.packet.Header()
-
-		copyPkt.Payload = pkt.packet.Payload()
-
-		t.onRead(copyPkt)
-
-		t.rtppool.PutPacket(copyPkt)
-
-		pkt.packet.Release()
-	}
-}
-
 func (t *remoteTrack) updateStats() {
 	s := t.statsGetter.Get(uint32(t.track.SSRC()))
 	if s == nil {
@@ -304,13 +219,6 @@ func (t *remoteTrack) IsRelay() bool {
 	return ok
 }
 
-func (t *remoteTrack) Buffer() *packetBuffers {
-	t.mu.RLock()
-	defer t.mu.RUnlock()
-
-	return t.packetBuffers
-}
-
 func (t *remoteTrack) OnEnded(f func()) {
 	t.mu.Lock()
 	defer t.mu.Unlock()
diff --git a/sfu.go b/sfu.go
index c666a4d..3e8a507 100644
--- a/sfu.go
+++ b/sfu.go
@@ -309,6 +309,10 @@ func (s *SFU) GetClient(id string) (*Client, error) {
 	return s.clients.GetClient(id)
 }
 
+func (s *SFU) GetClients() map[string]*Client {
+	return s.clients.GetClients()
+}
+
 func (s *SFU) removeClient(client *Client) error {
 	if err := s.clients.Remove(client); err != nil {
 		s.log.Errorf("sfu: failed to remove client ", err)
@@ -389,7 +393,7 @@ func (s *SFU) createExistingDataChannels(c *Client) {
 		}
 
 		if err := c.createDataChannel(dc.label, initOpts); err != nil {
-			s.log.Errorf("datachanel: error on create existing data channels, ", err)
+			s.log.Errorf("datachannel: error creating existing data channel %s: %s", dc.label, err.Error())
 		}
 	}
 }
diff --git a/track.go b/track.go
index ac0127b..f82c320 100644
--- a/track.go
+++ b/track.go
@@ -10,6 +10,7 @@ import (
 	"github.com/inlivedev/sfu/pkg/interceptors/voiceactivedetector"
 	"github.com/inlivedev/sfu/pkg/networkmonitor"
 	"github.com/inlivedev/sfu/pkg/rtppool"
+	"github.com/pion/interceptor"
 	"github.com/pion/interceptor/pkg/stats"
 	"github.com/pion/logging"
 	"github.com/pion/rtp"
@@ -55,14 +56,14 @@ type ITrack interface {
 	SetSourceType(TrackType)
 	SourceType() TrackType
 	SetAsProcessed()
-	OnRead(func(*rtp.Packet, QualityLevel))
+	OnRead(func(interceptor.Attributes, *rtp.Packet, QualityLevel))
 	IsScreen() bool
 	IsRelay() bool
 	Kind() webrtc.RTPCodecType
 	MimeType() string
 	TotalTracks() int
 	Context() context.Context
-	Relay(func(webrtc.SSRC, *rtp.Packet))
+	Relay(func(webrtc.SSRC, interceptor.Attributes, *rtp.Packet))
 	PayloadType() webrtc.PayloadType
 	OnEnded(func())
 }
@@ -73,7 +74,7 @@ type Track struct {
 	base             *baseTrack
 	remoteTrack      *remoteTrack
 	onEndedCallbacks []func()
-	onReadCallbacks  []func(*rtp.Packet, QualityLevel)
+	onReadCallbacks  []func(interceptor.Attributes, *rtp.Packet, QualityLevel)
 }
 
 type AudioTrack struct {
@@ -100,11 +101,11 @@ func newTrack(ctx context.Context, client *Client, trackRemote IRemoteTrack, min
 	t := &Track{
 		mu:               sync.Mutex{},
 		base:             baseTrack,
-		onReadCallbacks:  make([]func(*rtp.Packet, QualityLevel), 0),
+		onReadCallbacks:  make([]func(interceptor.Attributes, *rtp.Packet, QualityLevel), 0),
 		onEndedCallbacks: make([]func(), 0),
 	}
 
-	onRead := func(p *rtp.Packet) {
+	onRead := func(attrs interceptor.Attributes, p *rtp.Packet) {
 		tracks := t.base.clientTracks.GetTracks()
 
 		for _, track := range tracks {
@@ -129,7 +130,7 @@ func newTrack(ctx context.Context, client *Client, trackRemote IRemoteTrack, min
 		copyPacket.Header = *packet.Header()
 		copyPacket.Payload = packet.Payload()
 
-		t.onRead(copyPacket, QualityHigh)
+		t.onRead(attrs, copyPacket, QualityHigh)
 
 		pool.PutPacket(copyPacket)
 
@@ -336,29 +337,32 @@ func (t *Track) SetAsProcessed() {
 	t.base.isProcessed = true
 }
 
-func (t *Track) OnRead(callback func(*rtp.Packet, QualityLevel)) {
+func (t *Track) OnRead(callback func(interceptor.Attributes, *rtp.Packet, QualityLevel)) {
 	t.mu.Lock()
 	defer t.mu.Unlock()
 
 	t.onReadCallbacks = append(t.onReadCallbacks, callback)
 }
 
-func (t *Track) onRead(p *rtp.Packet, quality QualityLevel) {
+func (t *Track) onRead(attrs interceptor.Attributes, p *rtp.Packet, quality QualityLevel) {
+	callbacks := make([]func(interceptor.Attributes, *rtp.Packet, QualityLevel), 0)
+
 	t.mu.Lock()
-	defer t.mu.Unlock()
+	callbacks = append(callbacks, t.onReadCallbacks...)
+	t.mu.Unlock()
 
-	for _, callback := range t.onReadCallbacks {
+	for _, callback := range callbacks {
 		copyPacket := t.base.pool.GetPacket()
 		copyPacket.Header = p.Header
 		copyPacket.Payload = p.Payload
-		callback(p, quality)
+		callback(attrs, p, quality)
 		t.base.pool.PutPacket(copyPacket)
 	}
 }
 
-func (t *Track) Relay(f func(webrtc.SSRC, *rtp.Packet)) {
-	t.OnRead(func(p *rtp.Packet, quality QualityLevel) {
-		f(t.SSRC(), p)
+func (t *Track) Relay(f func(webrtc.SSRC, interceptor.Attributes, *rtp.Packet)) {
+	t.OnRead(func(attrs interceptor.Attributes, p *rtp.Packet, quality QualityLevel) {
+		f(t.SSRC(), attrs, p)
 	})
 }
 
@@ -412,7 +416,7 @@ type SimulcastTrack struct {
 	lastMidKeyframeTS           *atomic.Int64
 	lastLowKeyframeTS           *atomic.Int64
 	onAddedRemoteTrackCallbacks []func(*remoteTrack)
-	onReadCallbacks             []func(*rtp.Packet, QualityLevel)
+	onReadCallbacks             []func(interceptor.Attributes, *rtp.Packet, QualityLevel)
 	pliInterval                 time.Duration
 	onNetworkConditionChanged   func(networkmonitor.NetworkConditionType)
 	reordered                   bool
@@ -442,7 +446,7 @@ func newSimulcastTrack(client *Client, track IRemoteTrack, minWait, maxWait, pli
 		lastLowKeyframeTS:           &atomic.Int64{},
 		onTrackCompleteCallbacks:    make([]func(), 0),
 		onAddedRemoteTrackCallbacks: make([]func(*remoteTrack), 0),
-		onReadCallbacks:             make([]func(*rtp.Packet, QualityLevel), 0),
+		onReadCallbacks:             make([]func(interceptor.Attributes, *rtp.Packet, QualityLevel), 0),
 		pliInterval:                 pliInterval,
 		onNetworkConditionChanged: func(condition networkmonitor.NetworkConditionType) {
 			client.onNetworkConditionChanged(condition)
@@ -535,7 +539,7 @@ func (t *SimulcastTrack) AddRemoteTrack(track IRemoteTrack, minWait, maxWait tim
 
 	quality := RIDToQuality(track.RID())
 
-	onRead := func(p *rtp.Packet) {
+	onRead := func(attrs interceptor.Attributes, p *rtp.Packet) {
 		// set the base timestamp for the track if it is not set yet
 		if t.baseTS == 0 {
@@ -590,7 +594,7 @@ func (t *SimulcastTrack) AddRemoteTrack(track IRemoteTrack, minWait, maxWait tim
 		copyPacket.Header = *packet.Header()
 		copyPacket.Payload = packet.Payload()
 
-		t.onRead(copyPacket, quality)
+		t.onRead(attrs, copyPacket, quality)
 
 		t.base.pool.PutPacket(copyPacket)
 
@@ -808,16 +812,16 @@ func (t *SimulcastTrack) MimeType() string {
 	return t.base.codec.MimeType
 }
 
-func (t *SimulcastTrack) OnRead(callback func(*rtp.Packet, QualityLevel)) {
+func (t *SimulcastTrack) OnRead(callback func(interceptor.Attributes, *rtp.Packet, QualityLevel)) {
 	t.mu.Lock()
 	defer t.mu.Unlock()
 
 	t.onReadCallbacks = append(t.onReadCallbacks, callback)
 }
 
-func (t *SimulcastTrack) onRead(p *rtp.Packet, quality QualityLevel) {
+func (t *SimulcastTrack) onRead(attr interceptor.Attributes, p *rtp.Packet, quality QualityLevel) {
 	for _, callback := range t.onReadCallbacks {
-		callback(p, quality)
+		callback(attr, p, quality)
 	}
 }
 
@@ -887,15 +891,15 @@ func (t *SimulcastTrack) RIDLow() string {
 	return t.remoteTrackLow.track.RID()
 }
 
-func (t *SimulcastTrack) Relay(f func(webrtc.SSRC, *rtp.Packet)) {
-	t.OnRead(func(p *rtp.Packet, quality QualityLevel) {
+func (t *SimulcastTrack) Relay(f func(webrtc.SSRC, interceptor.Attributes, *rtp.Packet)) {
+	t.OnRead(func(attrs interceptor.Attributes, p *rtp.Packet, quality QualityLevel) {
 		switch quality {
 		case QualityHigh:
-			f(t.SSRCHigh(), p)
+			f(t.SSRCHigh(), attrs, p)
 		case QualityMid:
-			f(t.SSRCMid(), p)
+			f(t.SSRCMid(), attrs, p)
 		case QualityLow:
-			f(t.SSRCLow(), p)
+			f(t.SSRCLow(), attrs, p)
 		}
 	})
}
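
Usage sketch for reviewers (not part of the patch): how the new `OnTracksReady` hook and the attribute-aware `OnRead` callback introduced above could be wired together. The `github.com/inlivedev/sfu` import path, the `example` package, and the `attachTrackProcessing` helper are assumptions for illustration; the `OnTracksReady`, `ITrack.OnRead`, and `QualityLevel` signatures come from this diff.

```go
package example

import (
	"log"

	"github.com/inlivedev/sfu" // assumed module path for this repo
	"github.com/pion/interceptor"
	"github.com/pion/rtp"
)

// attachTrackProcessing (hypothetical helper) subscribes to a client's own
// published tracks once SetTracksSourceType has marked them ready, then taps
// every RTP packet, e.g. to feed a transcription or video-processing pipeline.
func attachTrackProcessing(client *sfu.Client) {
	client.OnTracksReady(func(tracks []sfu.ITrack) {
		for _, track := range tracks {
			log.Printf("track ready: kind=%s mime=%s", track.Kind(), track.MimeType())

			// OnRead now also delivers the interceptor attributes that
			// arrived with each packet (see the track.go changes above).
			track.OnRead(func(attrs interceptor.Attributes, p *rtp.Packet, _ sfu.QualityLevel) {
				// Packets are pooled by the SFU and only valid for the
				// duration of this callback; copy anything you keep.
				_ = attrs
				_ = p
			})
		}
	})
}
```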