From 5fe9dbd863940a2c4412bcea284bc0d2992f97c9 Mon Sep 17 00:00:00 2001 From: Yohan Totting Date: Sun, 14 Jul 2024 22:12:56 +0700 Subject: [PATCH] fix vad --- .vscode/launch.json | 11 ++- examples/http-websocket/index.html | 20 +++- examples/http-websocket/main.go | 11 ++- .../voiceactivedetector/interceptor.go | 9 +- .../voiceactivedetector/packetmanager.go | 3 +- pkg/interceptors/voiceactivedetector/vad.go | 97 +++++++++++++------ .../voiceactivedetector/vad_test.go | 4 +- 7 files changed, 112 insertions(+), 43 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 3a36e65..890266c 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -11,9 +11,14 @@ "mode": "auto", "program": "${workspaceFolder}/examples/http-websocket/main.go", "env": { - // "PIONS_LOG_DEBUG":"all", - // "PIONS_LOG_TRACE":"all", - // "PIONS_LOG_INFO":"all", + "PIONS_LOG_DEBUG":"sfu,vad", + "PIONS_LOG_TRACE":"sfu,vad", + "PIONS_LOG_INFO":"sfu,vad", + "PIONS_LOG_WARN":"sfu,vad", + "PIONS_LOG_ERROR":"sfu,vad", + "stderrthreshold":"DEBUG", + "logtostderr":"true", + }, "buildFlags": "-race" diff --git a/examples/http-websocket/index.html b/examples/http-websocket/index.html index 3e00b4b..cd355f3 100644 --- a/examples/http-websocket/index.html +++ b/examples/http-websocket/index.html @@ -275,6 +275,7 @@ const videoEl = document.getElementById("video-"+streamid) if (!videoEl){ + console.log("video element not found ",streamid) return } @@ -287,7 +288,24 @@ videoEl.style.border = "5px solid green" } - } + let vadEl = document.getElementById("vad-"+streamid) + if (!vadEl){ + vadEl = document.createElement("div"); + vadEl.id = "vad-"+streamid + const container = document.getElementById("container-"+streamid) + container.appendChild(vadEl) + } + + if (vad.data.audioLevels!==null){ + vadEl.innerText = Math.floor(vad.data.audioLevels.reduce(function (sum, value) { + return sum + value.audioLevel; + }, 0) / vad.data.audioLevels.length); + + } else { + vadEl.innerText = "0" + } + } 
+ function startH264() { start('h264') diff --git a/examples/http-websocket/main.go b/examples/http-websocket/main.go index 74a983b..4fea23e 100644 --- a/examples/http-websocket/main.go +++ b/examples/http-websocket/main.go @@ -71,10 +71,13 @@ var logger logging.LeveledLogger func main() { flag.Set("logtostderr", "true") - flag.Set("stderrthreshold", "INFO") - flag.Set("PIONS_LOG_INFO", "sfu") - flag.Set("PIONS_LOG_DEBUG", "sfu") - flag.Set("PIONS_LOG_TRACE", "sfu") + flag.Set("stderrthreshold", "DEBUG") + flag.Set("PIONS_LOG_INFO", "sfu,vad") + + flag.Set("PIONS_LOG_ERROR", "sfu,vad") + flag.Set("PIONS_LOG_WARN", "sfu,vad") + flag.Set("PIONS_LOG_DEBUG", "sfu,vad") + flag.Set("PIONS_LOG_TRACE", "sfu,vad") flag.Parse() diff --git a/pkg/interceptors/voiceactivedetector/interceptor.go b/pkg/interceptors/voiceactivedetector/interceptor.go index 40d88b5..46e66a0 100644 --- a/pkg/interceptors/voiceactivedetector/interceptor.go +++ b/pkg/interceptors/voiceactivedetector/interceptor.go @@ -105,7 +105,7 @@ func (v *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer inter defer v.mu.Unlock() if vad == nil { - v.vads[info.SSRC] = newVAD(v.context, v.config, info, v.log) + v.vads[info.SSRC] = newVAD(v.context, v.config, info) vad = v.vads[info.SSRC] } @@ -185,9 +185,12 @@ func (v *Interceptor) processPacket(ssrc uint32, header *rtp.Header) rtp.AudioLe return rtp.AudioLevelExtension{} } - vad.addPacket(header, audioData.Level) + if audioData.Voice { + vad.addPacket(header, audioData.Level, audioData.Voice) + } return audioData + } func (v *Interceptor) getConfig() Config { @@ -243,7 +246,7 @@ func (v *Interceptor) MapAudioTrack(ssrc uint32, t webrtc.TrackLocal) *VoiceDete vad := v.getVadBySSRC(ssrc) if vad == nil { - vad = newVAD(v.context, v.config, nil, v.log) + vad = newVAD(v.context, v.config, nil) v.mu.Lock() v.vads[ssrc] = vad v.mu.Unlock() diff --git a/pkg/interceptors/voiceactivedetector/packetmanager.go 
b/pkg/interceptors/voiceactivedetector/packetmanager.go index 57b41ff..c4fe1e7 100644 --- a/pkg/interceptors/voiceactivedetector/packetmanager.go +++ b/pkg/interceptors/voiceactivedetector/packetmanager.go @@ -24,7 +24,7 @@ func newPacketManager() *PacketManager { } } -func (m *PacketManager) NewPacket(seqNo uint16, timestamp uint32, audioLevel uint8) (*RetainablePacket, error) { +func (m *PacketManager) NewPacket(seqNo uint16, timestamp uint32, audioLevel uint8, isVoice bool) (*RetainablePacket, error) { p := &RetainablePacket{ onRelease: m.releasePacket, @@ -44,6 +44,7 @@ func (m *PacketManager) NewPacket(seqNo uint16, timestamp uint32, audioLevel uin p.data.SequenceNo = seqNo p.data.Timestamp = timestamp p.data.AudioLevel = audioLevel + p.data.IsVoice = isVoice return p, nil } diff --git a/pkg/interceptors/voiceactivedetector/vad.go b/pkg/interceptors/voiceactivedetector/vad.go index f772be2..7bf4242 100644 --- a/pkg/interceptors/voiceactivedetector/vad.go +++ b/pkg/interceptors/voiceactivedetector/vad.go @@ -14,6 +14,7 @@ type VoicePacketData struct { SequenceNo uint16 `json:"sequenceNo"` Timestamp uint32 `json:"timestamp"` AudioLevel uint8 `json:"audioLevel"` + IsVoice bool `json:"isVoice"` } type VoiceActivity struct { @@ -42,7 +43,7 @@ type VoiceDetector struct { log logging.LeveledLogger } -func newVAD(ctx context.Context, config Config, streamInfo *interceptor.StreamInfo, log logging.LeveledLogger) *VoiceDetector { +func newVAD(ctx context.Context, config Config, streamInfo *interceptor.StreamInfo) *VoiceDetector { v := &VoiceDetector{ context: ctx, config: config, @@ -51,7 +52,7 @@ func newVAD(ctx context.Context, config Config, streamInfo *interceptor.StreamIn mu: sync.RWMutex{}, VoicePackets: make([]*RetainablePacket, 0), packetManager: newPacketManager(), - log: log, + log: logging.NewDefaultLoggerFactory().NewLogger("vad"), } v.run() @@ -68,7 +69,7 @@ func newVAD(ctx context.Context, config Config, streamInfo *interceptor.StreamIn // once the tail 
margin close, stop send the packet. func (v *VoiceDetector) run() { go func() { - ticker := time.NewTicker(500 * time.Millisecond) + ticker := time.NewTicker(v.config.TailMargin) ctx, cancel := context.WithCancel(v.context) v.cancel = cancel @@ -77,23 +78,49 @@ func (v *VoiceDetector) run() { cancel() }() + active := false + lastSent := time.Now() + for { select { case <-ctx.Done(): return case voicePacket := <-v.channel: - - if v.isDetected(voicePacket) { + if voicePacket.data.AudioLevel < v.config.Threshold { // send all packets to callback - v.sendPacketsToCallback() + activity := VoiceActivity{ + TrackID: v.trackID, + StreamID: v.streamID, + SSRC: v.streamInfo.SSRC, + ClockRate: v.streamInfo.ClockRate, + AudioLevels: []*VoicePacketData{voicePacket.data}, + } + + v.onVoiceDetected(activity) + lastSent = time.Now() + active = true + + voicePacket.Release() } case <-ticker.C: - go v.dropExpiredPackets() + if active && time.Since(lastSent) > v.config.TailMargin { + // we need to notify that the voice has stopped + v.onVoiceDetected(VoiceActivity{ + TrackID: v.trackID, + StreamID: v.streamID, + SSRC: v.streamInfo.SSRC, + ClockRate: v.streamInfo.ClockRate, + AudioLevels: nil, + }) + active = false + } } } }() } +// TODO: this function is used together with the isDetected function; +// isDetected needs to be fixed first before this function can be used again. func (v *VoiceDetector) dropExpiredPackets() { loop: for { @@ -118,26 +145,35 @@ loop: } } -func (v *VoiceDetector) sendPacketsToCallback() { +func (v *VoiceDetector) sendPacketsToCallback() int { if v.callback == nil { - return + return 0 } // get all packets from head margin until tail margin packets := v.getPackets() - v.onVoiceDetected(VoiceActivity{ - TrackID: v.trackID, - StreamID: v.streamID, - SSRC: v.streamInfo.SSRC, - ClockRate: v.streamInfo.ClockRate, - AudioLevels: packets, - }) + length := len(packets) + + if length > 0 { + activity := VoiceActivity{ + TrackID: v.trackID, + StreamID: v.streamID, + SSRC:
v.streamInfo.SSRC, + ClockRate: v.streamInfo.ClockRate, + AudioLevels: packets, + } + + v.onVoiceDetected(activity) + + v.log.Debugf("voice detected: %v", activity) + } // clear packets v.clearPackets() + return length } func (v *VoiceDetector) clearPackets() { @@ -157,7 +193,9 @@ func (v *VoiceDetector) getPackets() []*VoicePacketData { var packets []*VoicePacketData for _, packet := range v.VoicePackets { - packets = append(packets, packet.Data()) + if packet.Data().AudioLevel < v.config.Threshold { + packets = append(packets, packet.Data()) + } } return packets @@ -182,6 +220,9 @@ func (v *VoiceDetector) OnVoiceDetected(callback func(VoiceActivity)) { v.callback = callback } +// TODO: this function sometimes stops detecting voice activity; +// we need to investigate why this is happening. +// For now we fall back to threshold-based detection. func (v *VoiceDetector) isDetected(vp *RetainablePacket) bool { v.mu.RLock() v.VoicePackets = append(v.VoicePackets, vp) @@ -198,19 +239,23 @@ func (v *VoiceDetector) isDetected(vp *RetainablePacket) bool { return v.detected } - isHeadMarginPassed := vp.Data().Timestamp*1000/clockRate > (v.startDetected*1000/clockRate)+uint32(v.config.HeadMargin.Milliseconds()) + currentTS := vp.Data().Timestamp + durationGap := (currentTS - v.startDetected) * 1000 / clockRate - isTailMarginPassedAfterStarted := vp.Data().Timestamp*1000/clockRate > (v.startDetected*1000/clockRate)+uint32(v.config.TailMargin.Milliseconds()) + isTailMarginPassedAfterStarted := durationGap > uint32(v.config.TailMargin.Milliseconds()) - // rest start detected timestamp if audio level above threshold after previously start detected + // reset the start-detected timestamp if the audio level is above the threshold after detection previously started if !v.detected && v.startDetected != 0 && isTailMarginPassedAfterStarted && !isThresholdPassed { v.startDetected = 0 return v.detected } + isHeadMarginPassed := durationGap > uint32(v.config.HeadMargin.Milliseconds()) + // detected true
after the audio level stay below threshold until pass the head margin if !v.detected && v.startDetected != 0 && isHeadMarginPassed { // start send packet to callback + v.log.Debugf("voice start detected %d ms ago", durationGap) v.detected = true v.lastDetectedTS = vp.Data().Timestamp @@ -223,13 +268,7 @@ func (v *VoiceDetector) isDetected(vp *RetainablePacket) bool { // stop send packet to callback v.detected = false v.startDetected = 0 - v.onVoiceDetected(VoiceActivity{ - TrackID: v.trackID, - StreamID: v.streamID, - SSRC: v.streamInfo.SSRC, - ClockRate: v.streamInfo.ClockRate, - AudioLevels: nil, - }) + return v.detected } @@ -241,8 +280,8 @@ func (v *VoiceDetector) isDetected(vp *RetainablePacket) bool { return v.detected } -func (v *VoiceDetector) addPacket(header *rtp.Header, audioLevel uint8) { - vp, err := v.packetManager.NewPacket(header.SequenceNumber, header.Timestamp, audioLevel) +func (v *VoiceDetector) addPacket(header *rtp.Header, audioLevel uint8, isVoice bool) { + vp, err := v.packetManager.NewPacket(header.SequenceNumber, header.Timestamp, audioLevel, isVoice) if err != nil { v.log.Errorf("failed to create new packet: %v", err) return diff --git a/pkg/interceptors/voiceactivedetector/vad_test.go b/pkg/interceptors/voiceactivedetector/vad_test.go index 161fa78..c305574 100644 --- a/pkg/interceptors/voiceactivedetector/vad_test.go +++ b/pkg/interceptors/voiceactivedetector/vad_test.go @@ -19,7 +19,7 @@ func BenchmarkVAD(b *testing.B) { vad := newVAD(ctx, intc.config, &interceptor.StreamInfo{ ID: "streamID", ClockRate: 48000, - }, leveledLogger) + }) vad.OnVoiceDetected(func(activity VoiceActivity) { // Do nothing @@ -27,7 +27,7 @@ func BenchmarkVAD(b *testing.B) { for i := 0; i < b.N; i++ { header := &rtp.Header{} - vad.addPacket(header, 30) + vad.addPacket(header, 3, false) } vad.cancel()