Skip to content

Commit

Permalink
fix vad
Browse files Browse the repository at this point in the history
  • Loading branch information
Yohan Totting committed Jul 14, 2024
1 parent 316cb3d commit 5fe9dbd
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 43 deletions.
11 changes: 8 additions & 3 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@
"mode": "auto",
"program": "${workspaceFolder}/examples/http-websocket/main.go",
"env": {
// "PIONS_LOG_DEBUG":"all",
// "PIONS_LOG_TRACE":"all",
// "PIONS_LOG_INFO":"all",
"PIONS_LOG_DEBUG":"sfu,vad",
"PIONS_LOG_TRACE":"sfu,vad",
"PIONS_LOG_INFO":"sfu,vad",
"PIONS_LOG_WARN":"sfu,vad",
"PIONS_LOG_ERROR":"sfu,vad",
"stderrthreshold":"DEBUG",
"logtostderr":"true",

},
"buildFlags": "-race"

Expand Down
20 changes: 19 additions & 1 deletion examples/http-websocket/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@

const videoEl = document.getElementById("video-"+streamid)
if (!videoEl){
console.log("video element not found ",streamid)
return
}

Expand All @@ -287,7 +288,24 @@
videoEl.style.border = "5px solid green"
}

}
let vadEl = document.getElementById("vad-"+streamid)
if (!vadEl){
vadEl = document.createElement("div");
vadEl.id = "vad-"+streamid
const container = document.getElementById("container-"+streamid)
container.appendChild(vadEl)
}

if (vad.data.audioLevels!==null){
vadEl.innerText = Math.floor(vad.data.audioLevels.reduce(function (sum, value) {
return sum + value.audioLevel;
}, 0) / vad.data.audioLevels.length);

} else {
vadEl.innerText = "0"
}
}


function startH264() {
start('h264')
Expand Down
11 changes: 7 additions & 4 deletions examples/http-websocket/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,13 @@ var logger logging.LeveledLogger

func main() {
flag.Set("logtostderr", "true")
flag.Set("stderrthreshold", "INFO")
flag.Set("PIONS_LOG_INFO", "sfu")
flag.Set("PIONS_LOG_DEBUG", "sfu")
flag.Set("PIONS_LOG_TRACE", "sfu")
flag.Set("stderrthreshold", "DEBUG")
flag.Set("PIONS_LOG_INFO", "sfu,vad")

flag.Set("PIONS_LOG_ERROR", "sfu,vad")
flag.Set("PIONS_LOG_WARN", "sfu,vad")
flag.Set("PIONS_LOG_DEBUG", "sfu,vad")
flag.Set("PIONS_LOG_TRACE", "sfu,vad")

flag.Parse()

Expand Down
9 changes: 6 additions & 3 deletions pkg/interceptors/voiceactivedetector/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (v *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer inter
defer v.mu.Unlock()

if vad == nil {
v.vads[info.SSRC] = newVAD(v.context, v.config, info, v.log)
v.vads[info.SSRC] = newVAD(v.context, v.config, info)
vad = v.vads[info.SSRC]
}

Expand Down Expand Up @@ -185,9 +185,12 @@ func (v *Interceptor) processPacket(ssrc uint32, header *rtp.Header) rtp.AudioLe
return rtp.AudioLevelExtension{}
}

vad.addPacket(header, audioData.Level)
if audioData.Voice {
vad.addPacket(header, audioData.Level, audioData.Voice)
}

return audioData

}

func (v *Interceptor) getConfig() Config {
Expand Down Expand Up @@ -243,7 +246,7 @@ func (v *Interceptor) MapAudioTrack(ssrc uint32, t webrtc.TrackLocal) *VoiceDete

vad := v.getVadBySSRC(ssrc)
if vad == nil {
vad = newVAD(v.context, v.config, nil, v.log)
vad = newVAD(v.context, v.config, nil)
v.mu.Lock()
v.vads[ssrc] = vad
v.mu.Unlock()
Expand Down
3 changes: 2 additions & 1 deletion pkg/interceptors/voiceactivedetector/packetmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func newPacketManager() *PacketManager {
}
}

func (m *PacketManager) NewPacket(seqNo uint16, timestamp uint32, audioLevel uint8) (*RetainablePacket, error) {
func (m *PacketManager) NewPacket(seqNo uint16, timestamp uint32, audioLevel uint8, isVoice bool) (*RetainablePacket, error) {

p := &RetainablePacket{
onRelease: m.releasePacket,
Expand All @@ -44,6 +44,7 @@ func (m *PacketManager) NewPacket(seqNo uint16, timestamp uint32, audioLevel uin
p.data.SequenceNo = seqNo
p.data.Timestamp = timestamp
p.data.AudioLevel = audioLevel
p.data.IsVoice = isVoice

return p, nil
}
Expand Down
97 changes: 68 additions & 29 deletions pkg/interceptors/voiceactivedetector/vad.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type VoicePacketData struct {
SequenceNo uint16 `json:"sequenceNo"`
Timestamp uint32 `json:"timestamp"`
AudioLevel uint8 `json:"audioLevel"`
IsVoice bool `json:"isVoice"`
}

type VoiceActivity struct {
Expand Down Expand Up @@ -42,7 +43,7 @@ type VoiceDetector struct {
log logging.LeveledLogger
}

func newVAD(ctx context.Context, config Config, streamInfo *interceptor.StreamInfo, log logging.LeveledLogger) *VoiceDetector {
func newVAD(ctx context.Context, config Config, streamInfo *interceptor.StreamInfo) *VoiceDetector {
v := &VoiceDetector{
context: ctx,
config: config,
Expand All @@ -51,7 +52,7 @@ func newVAD(ctx context.Context, config Config, streamInfo *interceptor.StreamIn
mu: sync.RWMutex{},
VoicePackets: make([]*RetainablePacket, 0),
packetManager: newPacketManager(),
log: log,
log: logging.NewDefaultLoggerFactory().NewLogger("vad"),
}

v.run()
Expand All @@ -68,7 +69,7 @@ func newVAD(ctx context.Context, config Config, streamInfo *interceptor.StreamIn
// once the tail margin closes, stop sending the packet.
func (v *VoiceDetector) run() {
go func() {
ticker := time.NewTicker(500 * time.Millisecond)
ticker := time.NewTicker(v.config.TailMargin)
ctx, cancel := context.WithCancel(v.context)
v.cancel = cancel

Expand All @@ -77,23 +78,49 @@ func (v *VoiceDetector) run() {
cancel()
}()

active := false
lastSent := time.Now()

for {
select {
case <-ctx.Done():
return
case voicePacket := <-v.channel:

if v.isDetected(voicePacket) {
if voicePacket.data.AudioLevel < v.config.Threshold {
// send all packets to callback
v.sendPacketsToCallback()
activity := VoiceActivity{
TrackID: v.trackID,
StreamID: v.streamID,
SSRC: v.streamInfo.SSRC,
ClockRate: v.streamInfo.ClockRate,
AudioLevels: []*VoicePacketData{voicePacket.data},
}

v.onVoiceDetected(activity)
lastSent = time.Now()
active = true

voicePacket.Release()
}
case <-ticker.C:
go v.dropExpiredPackets()
if active && time.Since(lastSent) > v.config.TailMargin {
// we need to notify that the voice is stopped
v.onVoiceDetected(VoiceActivity{
TrackID: v.trackID,
StreamID: v.streamID,
SSRC: v.streamInfo.SSRC,
ClockRate: v.streamInfo.ClockRate,
AudioLevels: nil,
})
active = false
}
}
}
}()
}

// TODO: this function is used together with the isDetected function;
// isDetected needs to be fixed first before this function can be used
func (v *VoiceDetector) dropExpiredPackets() {
loop:
for {
Expand All @@ -118,26 +145,35 @@ loop:
}
}

func (v *VoiceDetector) sendPacketsToCallback() {
func (v *VoiceDetector) sendPacketsToCallback() int {
if v.callback == nil {
return
return 0
}

// get all packets from head margin until tail margin

packets := v.getPackets()

v.onVoiceDetected(VoiceActivity{
TrackID: v.trackID,
StreamID: v.streamID,
SSRC: v.streamInfo.SSRC,
ClockRate: v.streamInfo.ClockRate,
AudioLevels: packets,
})
length := len(packets)

if length > 0 {
activity := VoiceActivity{
TrackID: v.trackID,
StreamID: v.streamID,
SSRC: v.streamInfo.SSRC,
ClockRate: v.streamInfo.ClockRate,
AudioLevels: packets,
}

v.onVoiceDetected(activity)

v.log.Debugf("voice detected: %v", activity)
}

// clear packets
v.clearPackets()

return length
}

func (v *VoiceDetector) clearPackets() {
Expand All @@ -157,7 +193,9 @@ func (v *VoiceDetector) getPackets() []*VoicePacketData {

var packets []*VoicePacketData
for _, packet := range v.VoicePackets {
packets = append(packets, packet.Data())
if packet.Data().AudioLevel < v.config.Threshold {
packets = append(packets, packet.Data())
}
}

return packets
Expand All @@ -182,6 +220,9 @@ func (v *VoiceDetector) OnVoiceDetected(callback func(VoiceActivity)) {
v.callback = callback
}

// TODO: this function sometimes stops detecting the voice activity;
// we need to investigate why this is happening.
// For now we fall back to threshold-based detection.
func (v *VoiceDetector) isDetected(vp *RetainablePacket) bool {
v.mu.RLock()
v.VoicePackets = append(v.VoicePackets, vp)
Expand All @@ -198,19 +239,23 @@ func (v *VoiceDetector) isDetected(vp *RetainablePacket) bool {
return v.detected
}

isHeadMarginPassed := vp.Data().Timestamp*1000/clockRate > (v.startDetected*1000/clockRate)+uint32(v.config.HeadMargin.Milliseconds())
currentTS := vp.Data().Timestamp
durationGap := (currentTS - v.startDetected) * 1000 / clockRate

isTailMarginPassedAfterStarted := vp.Data().Timestamp*1000/clockRate > (v.startDetected*1000/clockRate)+uint32(v.config.TailMargin.Milliseconds())
isTailMarginPassedAfterStarted := durationGap > uint32(v.config.TailMargin.Milliseconds())

// rest start detected timestamp if audio level above threshold after previously start detected
// restart detected timestamp if audio level above threshold after previously start detected
if !v.detected && v.startDetected != 0 && isTailMarginPassedAfterStarted && !isThresholdPassed {
v.startDetected = 0
return v.detected
}

isHeadMarginPassed := durationGap > uint32(v.config.HeadMargin.Milliseconds())

// detected true after the audio level stay below threshold until pass the head margin
if !v.detected && v.startDetected != 0 && isHeadMarginPassed {
// start send packet to callback
v.log.Debugf("voice start detected %d ms ago", durationGap)
v.detected = true
v.lastDetectedTS = vp.Data().Timestamp

Expand All @@ -223,13 +268,7 @@ func (v *VoiceDetector) isDetected(vp *RetainablePacket) bool {
// stop send packet to callback
v.detected = false
v.startDetected = 0
v.onVoiceDetected(VoiceActivity{
TrackID: v.trackID,
StreamID: v.streamID,
SSRC: v.streamInfo.SSRC,
ClockRate: v.streamInfo.ClockRate,
AudioLevels: nil,
})

return v.detected
}

Expand All @@ -241,8 +280,8 @@ func (v *VoiceDetector) isDetected(vp *RetainablePacket) bool {
return v.detected
}

func (v *VoiceDetector) addPacket(header *rtp.Header, audioLevel uint8) {
vp, err := v.packetManager.NewPacket(header.SequenceNumber, header.Timestamp, audioLevel)
func (v *VoiceDetector) addPacket(header *rtp.Header, audioLevel uint8, isVoice bool) {
vp, err := v.packetManager.NewPacket(header.SequenceNumber, header.Timestamp, audioLevel, isVoice)
if err != nil {
v.log.Errorf("failed to create new packet: %v", err)
return
Expand Down
4 changes: 2 additions & 2 deletions pkg/interceptors/voiceactivedetector/vad_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ func BenchmarkVAD(b *testing.B) {
vad := newVAD(ctx, intc.config, &interceptor.StreamInfo{
ID: "streamID",
ClockRate: 48000,
}, leveledLogger)
})

vad.OnVoiceDetected(func(activity VoiceActivity) {
// Do nothing
})

for i := 0; i < b.N; i++ {
header := &rtp.Header{}
vad.addPacket(header, 30)
vad.addPacket(header, 3, false)
}

vad.cancel()
Expand Down

0 comments on commit 5fe9dbd

Please sign in to comment.