Skip to content

Commit

Permalink
add voice activity detection
Browse files Browse the repository at this point in the history
  • Loading branch information
Yohan Totting committed Oct 31, 2023
1 parent dec7875 commit c5f93de
Show file tree
Hide file tree
Showing 13 changed files with 617 additions and 78 deletions.
7 changes: 7 additions & 0 deletions bitratecontroller.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ type bitrateClaim struct {
lastDecreaseTime time.Time
}

func (c *bitrateClaim) Quality() QualityLevel {
c.mu.Lock()
defer c.mu.Unlock()

return c.quality
}

func (c *bitrateClaim) isAllowToIncrease() bool {
c.mu.Lock()
defer c.mu.Unlock()
Expand Down
87 changes: 63 additions & 24 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/golang/glog"
"github.com/inlivedev/sfu/pkg/interceptors/voiceactivedetector"
"github.com/pion/interceptor"
"github.com/pion/interceptor/pkg/cc"
"github.com/pion/interceptor/pkg/stats"
Expand Down Expand Up @@ -49,9 +50,10 @@ var (
)

type ClientOptions struct {
Direction webrtc.RTPTransceiverDirection
IdleTimeout time.Duration
Type string
Direction webrtc.RTPTransceiverDirection
IdleTimeout time.Duration
Type string
EnableVoiceDetection bool
}

type internalDataMessage struct {
Expand Down Expand Up @@ -115,6 +117,7 @@ type Client struct {
onConnectionStateChangedCallbacks []func(webrtc.PeerConnectionState)
onJoinedCallbacks []func()
onLeftCallbacks []func()
onVoiceDetectedCallbacks []func(trackID, streamID string, SSRC uint32, voiceData []voiceactivedetector.VoicePacketData)
onTrackRemovedCallbacks []func(sourceType string, track *webrtc.TrackLocalStaticRTP)
onIceCandidate func(context.Context, *webrtc.ICECandidate)
onBeforeRenegotiation func(context.Context) bool
Expand All @@ -141,21 +144,28 @@ type Client struct {

func DefaultClientOptions() ClientOptions {
return ClientOptions{
Direction: webrtc.RTPTransceiverDirectionSendrecv,
IdleTimeout: 30 * time.Second,
Type: ClientTypePeer,
Direction: webrtc.RTPTransceiverDirectionSendrecv,
IdleTimeout: 30 * time.Second,
Type: ClientTypePeer,
EnableVoiceDetection: false,
}
}

func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Configuration, opts ClientOptions) *Client {
var client *Client
var vadInterceptor *voiceactivedetector.Interceptor

localCtx, cancel := context.WithCancel(s.context)
m := &webrtc.MediaEngine{}

if err := RegisterCodecs(m, s.codecs); err != nil {
panic(err)
}

RegisterSimulcastHeaderExtensions(m, webrtc.RTPCodecTypeVideo)
RegisterAudioLevelHeaderExtension(m)
if opts.EnableVoiceDetection {
voiceactivedetector.RegisterAudioLevelHeaderExtension(m)
}

// // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline.
// // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection`
Expand All @@ -175,6 +185,18 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi

i.Add(statsInterceptorFactory)

if opts.EnableVoiceDetection {
glog.Info("client: voice detection is enabled")
vadInterceptorFactory := voiceactivedetector.NewInterceptor(localCtx)

// enable voice detector
vadInterceptorFactory.OnNew(func(i *voiceactivedetector.Interceptor) {
vadInterceptor = i
})

i.Add(vadInterceptorFactory)
}

if err = webrtc.ConfigureTWCCHeaderExtensionSender(m, i); err != nil {
panic(err)
}
Expand All @@ -186,19 +208,6 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi
panic(err)
}

// // Register a intervalpli factory
// // This interceptor sends a PLI every 3 seconds. A PLI causes a video keyframe to be generated by the sender.
// // This makes our video seekable and more error resilent, but at a cost of lower picture quality and higher bitrates
// // A real world application should process incoming RTCP packets from viewers and forward them to senders

// pliOpts := intervalpli.GeneratorInterval(s.pliInterval)
// intervalPliFactory, err := intervalpli.NewReceiverInterceptor(pliOpts)
// if err != nil {
// panic(err)
// }

// i.Add(intervalPliFactory)

settingEngine := webrtc.SettingEngine{}

if s.mux != nil {
Expand All @@ -213,13 +222,13 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi

// add other clients tracks before generate the answer
// s.addOtherClientTracksBeforeSendAnswer(peerConnection)
localCtx, cancel := context.WithCancel(s.context)

var stateNew atomic.Value
stateNew.Store(ClientStateNew)

var quality atomic.Uint32
quality.Store(QualityHigh)
client := &Client{
client = &Client{
id: id,
name: name,
estimatorChan: estimatorChan,
Expand Down Expand Up @@ -266,12 +275,17 @@ func NewClient(s *SFU, id string, name string, peerConnectionConfig webrtc.Confi
// to connected peers
peerConnection.OnTrack(func(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
var track ITrack
var vad *voiceactivedetector.VoiceDetector

defer glog.Info("client: new track ", remoteTrack.ID(), " Kind:", remoteTrack.Kind(), " Codec: ", remoteTrack.Codec().MimeType, " RID: ", remoteTrack.RID())

if remoteTrack.Kind() == webrtc.RTPCodecTypeAudio && client.IsVADEnabled() {
vad = vadInterceptor.AddAudioTrack(remoteTrack)
}

if remoteTrack.RID() == "" {
// not simulcast
track = newTrack(client, remoteTrack, receiver)
track = newTrack(client, remoteTrack, receiver, vad)
if err := client.tracks.Add(track); err != nil {
glog.Error("client: error add track ", err)
}
Expand Down Expand Up @@ -766,7 +780,12 @@ func (c *Client) OnConnectionStateChanged(callback func(webrtc.PeerConnectionSta
}

func (c *Client) onConnectionStateChanged(state webrtc.PeerConnectionState) {
for _, callback := range c.onConnectionStateChangedCallbacks {
c.mu.Lock()
callbacks := make([]func(webrtc.PeerConnectionState), len(c.onConnectionStateChangedCallbacks))
copy(callbacks, c.onConnectionStateChangedCallbacks)
c.mu.Unlock()

for _, callback := range callbacks {
callback(webrtc.PeerConnectionState(state))
}
}
Expand Down Expand Up @@ -1180,3 +1199,23 @@ func (c *Client) SFU() *SFU {
func (c *Client) OnTracksAvailable(callback func([]ITrack)) {
c.onTracksAvailable = callback
}

func (c *Client) OnVoiceDetected(callback func(trackID, streamID string, ssrc uint32, voiceData []voiceactivedetector.VoicePacketData)) {
c.mu.Lock()
defer c.mu.Unlock()

c.onVoiceDetectedCallbacks = append(c.onVoiceDetectedCallbacks, callback)
}

func (c *Client) onVoiceDetected(trackID, streamID string, ssrc uint32, voiceData []voiceactivedetector.VoicePacketData) {
c.mu.Lock()
defer c.mu.Unlock()

for _, callback := range c.onVoiceDetectedCallbacks {
callback(trackID, streamID, ssrc, voiceData)
}
}

func (c *Client) IsVADEnabled() bool {
return c.options.EnableVoiceDetection
}
14 changes: 0 additions & 14 deletions clienttrack.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,6 @@ func (t *clientTrack) push(rtp *rtp.Packet, quality QualityLevel) {
}
}

func (t *clientTrack) getAudioLevel(p *rtp.Packet) rtp.AudioLevelExtension {
audioLevel := rtp.AudioLevelExtension{}
headerID := t.remoteTrack.getAudioLevelExtensionID()
if headerID != 0 {
ext := p.Header.GetExtension(headerID)
if err := audioLevel.Unmarshal(ext); err != nil {
glog.Error("clienttrack: error on unmarshal audio level", err)
}
}

return audioLevel

}

func (t *clientTrack) getCurrentBitrate() uint32 {
return t.remoteTrack.GetCurrentBitrate()
}
Expand Down
51 changes: 42 additions & 9 deletions examples/http-websocket/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,11 @@
video {
background-color: black;
min-width: 320px;
width: auto;
height: 100%;
width: 100%;
height: auto;
min-height: 240px;
object-fit: contain;
box-sizing: border-box;
}
</style>
<script type="module">
Expand Down Expand Up @@ -202,6 +203,8 @@
}
} else if(msg.type =='track_stats'){
updateTrackStats(msg.data)
} else if (msg.type=='voice_detected'){
updateVoiceDetected(msg.data)
}
} catch (error) {
console.log(error);
Expand All @@ -214,6 +217,23 @@
return promise
};

function updateVoiceDetected(vad){
const videoEl = document.getElementById("video-"+vad.stream_id)
if (!videoEl){
return
}


if (!vad.packets){
// voice ended
videoEl.style.border = "none"
} else {
// voice detected
videoEl.style.border = "5px solid green"
}

}

function startH264() {
start('h264')
}
Expand All @@ -226,7 +246,6 @@
await startWs()
document.getElementById("btnStart").disabled = true;
document.getElementById("btnStartVP9").disabled = true;
const localVideo = document.getElementById("localVideo");

const video = {
width: {ideal: 1280},
Expand All @@ -249,8 +268,25 @@
};

const stream = await navigator.mediaDevices.getUserMedia(constraints)
localVideo.srcObject = stream;
localVideo.play();
let container = document.getElementById("container-"+stream.id);
if (!container) {
container = document.createElement("div");
container.className = "container";
container.id = "container-"+stream.id;
document.querySelector('main').appendChild(container);
}

let localVideo = document.getElementById("video-"+stream.id);
if (!localVideo) {
localVideo = document.createElement("video");
localVideo.id = "video-"+stream.id;
localVideo.autoplay = true;
localVideo.muted = true;
container.appendChild(localVideo);
}

localVideo.srcObject =stream;


peerConnection.ontrack = function(e) {
e.streams.forEach((stream) => {
Expand Down Expand Up @@ -323,7 +359,7 @@
sendEncodings: [
{
maxBitrate: 1200*1000,
scalabilityMode: 'L3T3'
scalabilityMode: 'L3T2'
},

]
Expand Down Expand Up @@ -764,9 +800,6 @@ <h1>HTTP WebSocket Example</h1>
<p>Open the console to see the output.</p>
</header>
<main>
<div id="local">
<video id="localVideo" muted playsinline></video>
</div>
</main>
<aside>
<h3>Outbound Video</h3>
Expand Down
29 changes: 29 additions & 0 deletions examples/http-websocket/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/golang/glog"
"github.com/inlivedev/sfu"
"github.com/inlivedev/sfu/pkg/fakeclient"
"github.com/inlivedev/sfu/pkg/interceptors/voiceactivedetector"
"github.com/pion/webrtc/v3"
"golang.org/x/net/websocket"
)
Expand All @@ -30,6 +31,13 @@ type Respose struct {
Data interface{} `json:"data"`
}

type VAD struct {
SSRC uint32 `json:"ssrc"`
TrackID string `json:"track_id"`
StreamID string `json:"stream_id"`
Packets []voiceactivedetector.VoicePacketData `json:"packets"`
}

const (
TypeOffer = "offer"
TypeAnswer = "answer"
Expand All @@ -45,6 +53,7 @@ const (
TypeBitrateAdjusted = "bitrate_adjusted"
TypePacketLossPercentage = "set_packet_loss_percentage"
TypeTrackStats = "track_stats"
TypeVoiceDetected = "voice_detected"
)

func main() {
Expand Down Expand Up @@ -171,6 +180,7 @@ func clientHandler(isDebug bool, conn *websocket.Conn, messageChan chan Request,
// add a new client to room
// you can also get the client by using r.GetClient(clientID)
opts := sfu.DefaultClientOptions()
opts.EnableVoiceDetection = true
client, err := r.AddClient(clientID, clientID, opts)
if err != nil {
log.Panic(err)
Expand Down Expand Up @@ -289,6 +299,16 @@ func clientHandler(isDebug bool, conn *websocket.Conn, messageChan chan Request,
_, _ = conn.Write(candidateBytes)
})

vadChan := make(chan VAD)
client.OnVoiceDetected(func(trackID, streamID string, ssrc uint32, voiceData []voiceactivedetector.VoicePacketData) {
vadChan <- VAD{
SSRC: ssrc,
TrackID: trackID,
StreamID: streamID,
Packets: voiceData,
}
})

ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()

Expand All @@ -308,6 +328,15 @@ func clientHandler(isDebug bool, conn *websocket.Conn, messageChan chan Request,
respBytes, _ := json.Marshal(resp)
_, _ = conn.Write(respBytes)
}
case vad := <-vadChan:
resp := Respose{
Status: true,
Type: TypeVoiceDetected,
Data: vad,
}

respBytes, _ := json.Marshal(resp)
_, _ = conn.Write(respBytes)
case req := <-messageChan:
// handle as SDP if no error
if req.Type == TypeOffer || req.Type == TypeAnswer {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package simulcastinterceptor
package simulcast

import (
"sync"
Expand All @@ -13,7 +13,7 @@ type InterceptorFactory struct {
onNew func(i *Interceptor)
}

func New() *InterceptorFactory {
func NewInterceptor() *InterceptorFactory {
return &InterceptorFactory{}
}

Expand Down
Loading

0 comments on commit c5f93de

Please sign in to comment.