From f9d923acad628558eb7a6bd2faea313763ca2e6a Mon Sep 17 00:00:00 2001 From: Yohan Totting Date: Mon, 22 Apr 2024 15:25:42 +0700 Subject: [PATCH] fix: voice active detection --- client.go | 26 ++++++++-- clientstats.go | 3 ++ examples/http-websocket/index.html | 15 +++--- .../voiceactivedetector/interceptor.go | 52 ++++++++++++------- pkg/interceptors/voiceactivedetector/vad.go | 23 ++++---- track.go | 12 +---- 6 files changed, 75 insertions(+), 56 deletions(-) diff --git a/client.go b/client.go index 4d5a98e..5cac258 100644 --- a/client.go +++ b/client.go @@ -173,7 +173,7 @@ type Client struct { ingressBandwidth *atomic.Uint32 ingressQualityLimitationReason *atomic.Value isDebug bool - vad *voiceactivedetector.Interceptor + vadInterceptor *voiceactivedetector.Interceptor } func DefaultClientOptions() ClientOptions { @@ -192,8 +192,7 @@ func DefaultClientOptions() ClientOptions { func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Configuration, opts ClientOptions) *Client { var client *Client - - var vad *voiceactivedetector.Interceptor + var vadInterceptor *voiceactivedetector.Interceptor localCtx, cancel := context.WithCancel(s.context) m := &webrtc.MediaEngine{} @@ -233,7 +232,16 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi // enable voice detector vadInterceptorFactory.OnNew(func(i *voiceactivedetector.Interceptor) { - vad = i + vadInterceptor = i + i.OnNewVAD(func(vad *voiceactivedetector.VoiceDetector) { + glog.Info("track: voice activity detector enabled") + vad.OnVoiceDetected(func(activity voiceactivedetector.VoiceActivity) { + // send through datachannel + if client != nil { + client.onVoiceDetected(activity) + } + }) + }) }) i.Add(vadInterceptorFactory) @@ -335,7 +343,7 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi ingressBandwidth: &atomic.Uint32{}, ingressQualityLimitationReason: &atomic.Value{}, onTracksAvailableCallbacks: make([]func([]ITrack), 0), - vad: vad, + vadInterceptor: vadInterceptor, } // make sure the exisiting data channels is created on new clients @@ -908,6 +916,14 @@ func (c *Client) setClientTrack(t ITrack) iClientTrack { return nil } + if t.Kind() == webrtc.RTPCodecTypeAudio { + if c.IsVADEnabled() { + glog.Info("track: voice activity detector enabled") + ssrc := senderTcv.Sender().GetParameters().Encodings[0].SSRC + c.vadInterceptor.MapAudioTrack(uint32(ssrc), localTrack) + } + } + go func() { ctx, cancel := context.WithCancel(outputTrack.Context()) defer cancel() diff --git a/clientstats.go b/clientstats.go index 68a21a1..19ad4d6 100644 --- a/clientstats.go +++ b/clientstats.go @@ -182,6 +182,9 @@ func (c *ClientStats) SetReceiver(id, rid string, stats stats.Stats) { // UpdateVoiceActivity updates voice activity duration // 0 timestamp means ended func (c *ClientStats) UpdateVoiceActivity(ts uint32) { + c.mu.Lock() + defer c.mu.Unlock() + if !c.voiceActivity.active && ts != 0 { c.voiceActivity.active = true c.voiceActivity.start = ts diff --git a/examples/http-websocket/index.html b/examples/http-websocket/index.html index a007a6b..bd2124f 100644 --- a/examples/http-websocket/index.html +++ b/examples/http-websocket/index.html @@ -132,6 +132,8 @@ let negotiationNeeded = false + const trackStreamIDs = {} + peerConnection.ondatachannel = function(e) { console.log("ondatachannel: ",e.channel.label); if (e.channel.label == "internal") { @@ -245,7 +247,9 @@ }; function updateVoiceDetected(vad){ - const videoEl = document.getElementById("video-"+vad.data.stream_id) + const streamid = vad.data.streamID + + const videoEl = document.getElementById("video-"+streamid) if (!videoEl){ return } @@ -321,20 +325,19 @@ e.streams.forEach((stream) => { console.log("ontrack", stream, e.track); - const streamid = stream.id.replace('{','').replace('}',''); - let container = document.getElementById("container-"+streamid); + let container = document.getElementById("container-"+stream.id); if (!container) { container = document.createElement("div"); container.className = "container"; - container.id = "container-"+streamid; + container.id = "container-"+stream.id; document.querySelector('main').appendChild(container); } - let remoteVideo = document.getElementById("video-"+streamid); + let remoteVideo = document.getElementById("video-"+stream.id); if (!remoteVideo) { remoteVideo = document.createElement("video"); - remoteVideo.id = "video-"+streamid; + remoteVideo.id = "video-"+stream.id; remoteVideo.autoplay = true; container.appendChild(remoteVideo); if (videoObserver!=null){ diff --git a/pkg/interceptors/voiceactivedetector/interceptor.go b/pkg/interceptors/voiceactivedetector/interceptor.go index 1b8b459..a492041 100644 --- a/pkg/interceptors/voiceactivedetector/interceptor.go +++ b/pkg/interceptors/voiceactivedetector/interceptor.go @@ -57,8 +57,9 @@ func DefaultConfig() Config { type Interceptor struct { context context.Context mu sync.RWMutex - vads map[string]*VoiceDetector + vads map[uint32]*VoiceDetector config Config + onNew func(vad *VoiceDetector) } func new(ctx context.Context) *Interceptor { @@ -66,7 +67,7 @@ func new(ctx context.Context) *Interceptor { context: ctx, mu: sync.RWMutex{}, config: DefaultConfig(), - vads: make(map[string]*VoiceDetector), + vads: make(map[uint32]*VoiceDetector), } } @@ -77,6 +78,13 @@ func (v *Interceptor) SetConfig(config Config) { v.config = config } +func (v *Interceptor) OnNewVAD(callback func(vad *VoiceDetector)) { + v.mu.Lock() + defer v.mu.Unlock() + + v.onNew = callback +} + // BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method // will be called once per rtp packet. func (v *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { @@ -84,7 +92,7 @@ func (v *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer inter return writer } - vad := v.getVadByID(info.ID) + vad := v.getVadBySSRC(info.SSRC) if vad != nil { vad.updateStreamInfo(info) } @@ -93,19 +101,23 @@ func (v *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer inter defer v.mu.Unlock() if vad == nil { - v.vads[info.ID] = newVAD(v.context, v, info) + v.vads[info.SSRC] = newVAD(v.context, v, info) + vad = v.vads[info.SSRC] + } + if v.onNew != nil { + v.onNew(vad) } return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { - _ = v.processPacket(info.ID, header) + _ = v.processPacket(info.SSRC, header) return writer.Write(header, payload, attributes) }) } // UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track. func (v *Interceptor) UnbindLocalStream(info *interceptor.StreamInfo) { - vad := v.getVadByID(info.ID) + vad := v.getVadBySSRC(info.SSRC) if vad != nil { vad.Stop() } @@ -113,7 +125,7 @@ func (v *Interceptor) UnbindLocalStream(info *interceptor.StreamInfo) { v.mu.Lock() defer v.mu.Unlock() - delete(v.vads, info.ID) + delete(v.vads, info.SSRC) } @@ -144,11 +156,11 @@ func (v *Interceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor. return writer } -func (v *Interceptor) getVadByID(id string) *VoiceDetector { +func (v *Interceptor) getVadBySSRC(ssrc uint32) *VoiceDetector { v.mu.RLock() defer v.mu.RUnlock() - vad, ok := v.vads[id] + vad, ok := v.vads[ssrc] if ok { return vad } @@ -156,15 +168,15 @@ func (v *Interceptor) getVadByID(id string) *VoiceDetector { return nil } -func (v *Interceptor) processPacket(id string, header *rtp.Header) rtp.AudioLevelExtension { - audioData := v.getAudioLevel(id, header) +func (v *Interceptor) processPacket(ssrc uint32, header *rtp.Header) rtp.AudioLevelExtension { + audioData := v.getAudioLevel(ssrc, header) if audioData.Level == 0 { return rtp.AudioLevelExtension{} } - vad := v.getVadByID(id) + vad := v.getVadBySSRC(ssrc) if vad == nil { - glog.Error("vad: not found vad for track id", id) + glog.Error("vad: not found vad for track ssrc", ssrc) return rtp.AudioLevelExtension{} } @@ -180,9 +192,9 @@ func (v *Interceptor) getConfig() Config { return v.config } -func (v *Interceptor) getAudioLevel(id string, header *rtp.Header) rtp.AudioLevelExtension { +func (v *Interceptor) getAudioLevel(ssrc uint32, header *rtp.Header) rtp.AudioLevelExtension { audioLevel := rtp.AudioLevelExtension{} - headerID := v.getAudioLevelExtensionID(id) + headerID := v.getAudioLevelExtensionID(ssrc) if headerID != 0 { ext := header.GetExtension(headerID) _ = audioLevel.Unmarshal(ext) @@ -197,8 +209,8 @@ func RegisterAudioLevelHeaderExtension(m *webrtc.MediaEngine) { } } -func (v *Interceptor) getAudioLevelExtensionID(id string) uint8 { - vad := v.getVadByID(id) +func (v *Interceptor) getAudioLevelExtensionID(ssrc uint32) uint8 { + vad := v.getVadBySSRC(ssrc) if vad != nil { for _, extension := range vad.streamInfo.RTPHeaderExtensions { if extension.URI == sdp.AudioLevelURI { @@ -211,17 +223,17 @@ func (v *Interceptor) getAudioLevelExtensionID(id string) uint8 { } // AddAudioTrack adds audio track to interceptor -func (v *Interceptor) AddAudioTrack(t webrtc.TrackLocal) *VoiceDetector { +func (v *Interceptor) MapAudioTrack(ssrc uint32, t webrtc.TrackLocal) *VoiceDetector { if t.Kind() != webrtc.RTPCodecTypeAudio { glog.Error("vad: track is not audio track") return nil } - vad := v.getVadByID(t.ID()) + vad := v.getVadBySSRC(ssrc) if vad == nil { vad = newVAD(v.context, v, nil) v.mu.Lock() - v.vads[t.ID()] = vad + v.vads[ssrc] = vad v.mu.Unlock() } diff --git a/pkg/interceptors/voiceactivedetector/vad.go b/pkg/interceptors/voiceactivedetector/vad.go index ae4751c..02e0acc 100644 --- a/pkg/interceptors/voiceactivedetector/vad.go +++ b/pkg/interceptors/voiceactivedetector/vad.go @@ -36,7 +36,7 @@ type VoiceDetector struct { channel chan VoicePacketData mu sync.RWMutex VoicePackets []VoicePacketData - callbacks []func(activity VoiceActivity) + callback func(activity VoiceActivity) } func newVAD(ctx context.Context, i *Interceptor, streamInfo *interceptor.StreamInfo) *VoiceDetector { @@ -47,7 +47,6 @@ func newVAD(ctx context.Context, i *Interceptor, streamInfo *interceptor.StreamI channel: make(chan VoicePacketData), mu: sync.RWMutex{}, VoicePackets: make([]VoicePacketData, 0), - callbacks: make([]func(VoiceActivity), 0), } v.run() @@ -113,8 +112,7 @@ loop: } func (v *VoiceDetector) sendPacketsToCallback() { - noCallbacks := len(v.callbacks) == 0 - if noCallbacks { + if v.callback == nil { return } @@ -147,11 +145,11 @@ func (v *VoiceDetector) getPackets() []VoicePacketData { } func (v *VoiceDetector) onVoiceDetected(activity VoiceActivity) { - v.mu.Lock() - defer v.mu.Unlock() + v.mu.RLock() + defer v.mu.RUnlock() - for _, callback := range v.callbacks { - callback(activity) + if v.callback != nil { + v.callback(activity) } } @@ -160,9 +158,9 @@ func (v *VoiceDetector) OnVoiceDetected(callback func(VoiceActivity)) { return } - v.mu.RLock() - defer v.mu.RUnlock() - v.callbacks = append(v.callbacks, callback) + v.mu.Lock() + defer v.mu.Unlock() + v.callback = callback } func (v *VoiceDetector) isDetected(vp VoicePacketData) bool { @@ -249,7 +247,4 @@ func (v *VoiceDetector) updateStreamInfo(streamInfo *interceptor.StreamInfo) { defer v.mu.RUnlock() v.streamInfo = streamInfo - if streamInfo.ID != "" { - v.trackID = streamInfo.ID - } } diff --git a/track.go b/track.go index cf54930..551d084 100644 --- a/track.go +++ b/track.go @@ -8,7 +8,6 @@ import ( "time" "github.com/golang/glog" - "github.com/inlivedev/sfu/pkg/interceptors/voiceactivedetector" "github.com/inlivedev/sfu/pkg/networkmonitor" "github.com/inlivedev/sfu/pkg/rtppool" "github.com/pion/interceptor/pkg/stats" @@ -254,16 +253,7 @@ func (t *Track) subscribe(c *Client) iClientTrack { } - if t.Kind() == webrtc.RTPCodecTypeAudio { - if c.IsVADEnabled() { - glog.Info("track: voice activity detector enabled") - vad := c.vad.AddAudioTrack(ct.LocalTrack()) - vad.OnVoiceDetected(func(activity voiceactivedetector.VoiceActivity) { - // send through datachannel - c.onVoiceDetected(activity) - }) - } - } else if t.Kind() == webrtc.RTPCodecTypeVideo { + if t.Kind() == webrtc.RTPCodecTypeVideo { t.remoteTrack.sendPLI() }