Skip to content

Commit

Permalink
fix: voice active detection
Browse files Browse the repository at this point in the history
  • Loading branch information
Yohan Totting committed Apr 22, 2024
1 parent c9ab525 commit f9d923a
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 56 deletions.
26 changes: 21 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ type Client struct {
ingressBandwidth *atomic.Uint32
ingressQualityLimitationReason *atomic.Value
isDebug bool
vad *voiceactivedetector.Interceptor
vadInterceptor *voiceactivedetector.Interceptor
}

func DefaultClientOptions() ClientOptions {
Expand All @@ -192,8 +192,7 @@ func DefaultClientOptions() ClientOptions {

func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Configuration, opts ClientOptions) *Client {
var client *Client

var vad *voiceactivedetector.Interceptor
var vadInterceptor *voiceactivedetector.Interceptor

localCtx, cancel := context.WithCancel(s.context)
m := &webrtc.MediaEngine{}
Expand Down Expand Up @@ -233,7 +232,16 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi

// enable voice detector
vadInterceptorFactory.OnNew(func(i *voiceactivedetector.Interceptor) {
vad = i
vadInterceptor = i
i.OnNewVAD(func(vad *voiceactivedetector.VoiceDetector) {
glog.Info("track: voice activity detector enabled")
vad.OnVoiceDetected(func(activity voiceactivedetector.VoiceActivity) {
// send through datachannel
if client != nil {
client.onVoiceDetected(activity)
}
})
})
})

i.Add(vadInterceptorFactory)
Expand Down Expand Up @@ -335,7 +343,7 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi
ingressBandwidth: &atomic.Uint32{},
ingressQualityLimitationReason: &atomic.Value{},
onTracksAvailableCallbacks: make([]func([]ITrack), 0),
vad: vad,
vadInterceptor: vadInterceptor,
}

// make sure the exisiting data channels is created on new clients
Expand Down Expand Up @@ -908,6 +916,14 @@ func (c *Client) setClientTrack(t ITrack) iClientTrack {
return nil
}

if t.Kind() == webrtc.RTPCodecTypeAudio {
if c.IsVADEnabled() {
glog.Info("track: voice activity detector enabled")
ssrc := senderTcv.Sender().GetParameters().Encodings[0].SSRC
c.vadInterceptor.MapAudioTrack(uint32(ssrc), localTrack)
}
}

go func() {
ctx, cancel := context.WithCancel(outputTrack.Context())
defer cancel()
Expand Down
3 changes: 3 additions & 0 deletions clientstats.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ func (c *ClientStats) SetReceiver(id, rid string, stats stats.Stats) {
// UpdateVoiceActivity updates voice activity duration
// 0 timestamp means ended
func (c *ClientStats) UpdateVoiceActivity(ts uint32) {
c.mu.Lock()
defer c.mu.Unlock()

if !c.voiceActivity.active && ts != 0 {
c.voiceActivity.active = true
c.voiceActivity.start = ts
Expand Down
15 changes: 9 additions & 6 deletions examples/http-websocket/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@

let negotiationNeeded = false

const trackStreamIDs = {}

peerConnection.ondatachannel = function(e) {
console.log("ondatachannel: ",e.channel.label);
if (e.channel.label == "internal") {
Expand Down Expand Up @@ -245,7 +247,9 @@
};

function updateVoiceDetected(vad){
const videoEl = document.getElementById("video-"+vad.data.stream_id)
const streamid = vad.data.streamID

const videoEl = document.getElementById("video-"+streamid)
if (!videoEl){
return
}
Expand Down Expand Up @@ -321,20 +325,19 @@

e.streams.forEach((stream) => {
console.log("ontrack", stream, e.track);
const streamid = stream.id.replace('{','').replace('}','');

let container = document.getElementById("container-"+streamid);
let container = document.getElementById("container-"+stream.id);
if (!container) {
container = document.createElement("div");
container.className = "container";
container.id = "container-"+streamid;
container.id = "container-"+stream.id;
document.querySelector('main').appendChild(container);
}

let remoteVideo = document.getElementById("video-"+streamid);
let remoteVideo = document.getElementById("video-"+stream.id);
if (!remoteVideo) {
remoteVideo = document.createElement("video");
remoteVideo.id = "video-"+streamid;
remoteVideo.id = "video-"+stream.id;
remoteVideo.autoplay = true;
container.appendChild(remoteVideo);
if (videoObserver!=null){
Expand Down
52 changes: 32 additions & 20 deletions pkg/interceptors/voiceactivedetector/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,17 @@ func DefaultConfig() Config {
type Interceptor struct {
context context.Context
mu sync.RWMutex
vads map[string]*VoiceDetector
vads map[uint32]*VoiceDetector
config Config
onNew func(vad *VoiceDetector)
}

func new(ctx context.Context) *Interceptor {
return &Interceptor{
context: ctx,
mu: sync.RWMutex{},
config: DefaultConfig(),
vads: make(map[string]*VoiceDetector),
vads: make(map[uint32]*VoiceDetector),
}
}

Expand All @@ -77,14 +78,21 @@ func (v *Interceptor) SetConfig(config Config) {
v.config = config
}

func (v *Interceptor) OnNewVAD(callback func(vad *VoiceDetector)) {
v.mu.Lock()
defer v.mu.Unlock()

v.onNew = callback
}

// BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method
// will be called once per rtp packet.
func (v *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
if info.MimeType != webrtc.MimeTypeOpus && info.MimeType != "audio/red" {
return writer
}

vad := v.getVadByID(info.ID)
vad := v.getVadBySSRC(info.SSRC)
if vad != nil {
vad.updateStreamInfo(info)
}
Expand All @@ -93,27 +101,31 @@ func (v *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer inter
defer v.mu.Unlock()

if vad == nil {
v.vads[info.ID] = newVAD(v.context, v, info)
v.vads[info.SSRC] = newVAD(v.context, v, info)
vad = v.vads[info.SSRC]
}

if v.onNew != nil {
v.onNew(vad)
}

return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
_ = v.processPacket(info.ID, header)
_ = v.processPacket(info.SSRC, header)
return writer.Write(header, payload, attributes)
})
}

// UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track.
func (v *Interceptor) UnbindLocalStream(info *interceptor.StreamInfo) {
vad := v.getVadByID(info.ID)
vad := v.getVadBySSRC(info.SSRC)
if vad != nil {
vad.Stop()
}

v.mu.Lock()
defer v.mu.Unlock()

delete(v.vads, info.ID)
delete(v.vads, info.SSRC)

}

Expand Down Expand Up @@ -144,27 +156,27 @@ func (v *Interceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.
return writer
}

func (v *Interceptor) getVadByID(id string) *VoiceDetector {
func (v *Interceptor) getVadBySSRC(ssrc uint32) *VoiceDetector {
v.mu.RLock()
defer v.mu.RUnlock()

vad, ok := v.vads[id]
vad, ok := v.vads[ssrc]
if ok {
return vad
}

return nil
}

func (v *Interceptor) processPacket(id string, header *rtp.Header) rtp.AudioLevelExtension {
audioData := v.getAudioLevel(id, header)
func (v *Interceptor) processPacket(ssrc uint32, header *rtp.Header) rtp.AudioLevelExtension {
audioData := v.getAudioLevel(ssrc, header)
if audioData.Level == 0 {
return rtp.AudioLevelExtension{}
}

vad := v.getVadByID(id)
vad := v.getVadBySSRC(ssrc)
if vad == nil {
glog.Error("vad: not found vad for track id", id)
glog.Error("vad: not found vad for track ssrc", ssrc)
return rtp.AudioLevelExtension{}
}

Expand All @@ -180,9 +192,9 @@ func (v *Interceptor) getConfig() Config {
return v.config
}

func (v *Interceptor) getAudioLevel(id string, header *rtp.Header) rtp.AudioLevelExtension {
func (v *Interceptor) getAudioLevel(ssrc uint32, header *rtp.Header) rtp.AudioLevelExtension {
audioLevel := rtp.AudioLevelExtension{}
headerID := v.getAudioLevelExtensionID(id)
headerID := v.getAudioLevelExtensionID(ssrc)
if headerID != 0 {
ext := header.GetExtension(headerID)
_ = audioLevel.Unmarshal(ext)
Expand All @@ -197,8 +209,8 @@ func RegisterAudioLevelHeaderExtension(m *webrtc.MediaEngine) {
}
}

func (v *Interceptor) getAudioLevelExtensionID(id string) uint8 {
vad := v.getVadByID(id)
func (v *Interceptor) getAudioLevelExtensionID(ssrc uint32) uint8 {
vad := v.getVadBySSRC(ssrc)
if vad != nil {
for _, extension := range vad.streamInfo.RTPHeaderExtensions {
if extension.URI == sdp.AudioLevelURI {
Expand All @@ -211,17 +223,17 @@ func (v *Interceptor) getAudioLevelExtensionID(id string) uint8 {
}

// AddAudioTrack adds audio track to interceptor
func (v *Interceptor) AddAudioTrack(t webrtc.TrackLocal) *VoiceDetector {
func (v *Interceptor) MapAudioTrack(ssrc uint32, t webrtc.TrackLocal) *VoiceDetector {
if t.Kind() != webrtc.RTPCodecTypeAudio {
glog.Error("vad: track is not audio track")
return nil
}

vad := v.getVadByID(t.ID())
vad := v.getVadBySSRC(ssrc)
if vad == nil {
vad = newVAD(v.context, v, nil)
v.mu.Lock()
v.vads[t.ID()] = vad
v.vads[ssrc] = vad
v.mu.Unlock()
}

Expand Down
23 changes: 9 additions & 14 deletions pkg/interceptors/voiceactivedetector/vad.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type VoiceDetector struct {
channel chan VoicePacketData
mu sync.RWMutex
VoicePackets []VoicePacketData
callbacks []func(activity VoiceActivity)
callback func(activity VoiceActivity)
}

func newVAD(ctx context.Context, i *Interceptor, streamInfo *interceptor.StreamInfo) *VoiceDetector {
Expand All @@ -47,7 +47,6 @@ func newVAD(ctx context.Context, i *Interceptor, streamInfo *interceptor.StreamI
channel: make(chan VoicePacketData),
mu: sync.RWMutex{},
VoicePackets: make([]VoicePacketData, 0),
callbacks: make([]func(VoiceActivity), 0),
}

v.run()
Expand Down Expand Up @@ -113,8 +112,7 @@ loop:
}

func (v *VoiceDetector) sendPacketsToCallback() {
noCallbacks := len(v.callbacks) == 0
if noCallbacks {
if v.callback == nil {
return
}

Expand Down Expand Up @@ -147,11 +145,11 @@ func (v *VoiceDetector) getPackets() []VoicePacketData {
}

func (v *VoiceDetector) onVoiceDetected(activity VoiceActivity) {
v.mu.Lock()
defer v.mu.Unlock()
v.mu.RLock()
defer v.mu.RUnlock()

for _, callback := range v.callbacks {
callback(activity)
if v.callback != nil {
v.callback(activity)
}
}

Expand All @@ -160,9 +158,9 @@ func (v *VoiceDetector) OnVoiceDetected(callback func(VoiceActivity)) {
return
}

v.mu.RLock()
defer v.mu.RUnlock()
v.callbacks = append(v.callbacks, callback)
v.mu.Lock()
defer v.mu.Unlock()
v.callback = callback
}

func (v *VoiceDetector) isDetected(vp VoicePacketData) bool {
Expand Down Expand Up @@ -249,7 +247,4 @@ func (v *VoiceDetector) updateStreamInfo(streamInfo *interceptor.StreamInfo) {
defer v.mu.RUnlock()

v.streamInfo = streamInfo
if streamInfo.ID != "" {
v.trackID = streamInfo.ID
}
}
12 changes: 1 addition & 11 deletions track.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"time"

"github.com/golang/glog"
"github.com/inlivedev/sfu/pkg/interceptors/voiceactivedetector"
"github.com/inlivedev/sfu/pkg/networkmonitor"
"github.com/inlivedev/sfu/pkg/rtppool"
"github.com/pion/interceptor/pkg/stats"
Expand Down Expand Up @@ -254,16 +253,7 @@ func (t *Track) subscribe(c *Client) iClientTrack {

}

if t.Kind() == webrtc.RTPCodecTypeAudio {
if c.IsVADEnabled() {
glog.Info("track: voice activity detector enabled")
vad := c.vad.AddAudioTrack(ct.LocalTrack())
vad.OnVoiceDetected(func(activity voiceactivedetector.VoiceActivity) {
// send through datachannel
c.onVoiceDetected(activity)
})
}
} else if t.Kind() == webrtc.RTPCodecTypeVideo {
if t.Kind() == webrtc.RTPCodecTypeVideo {
t.remoteTrack.sendPLI()
}

Expand Down

0 comments on commit f9d923a

Please sign in to comment.