diff --git a/agents/src/multimodal/multimodal_agent.ts b/agents/src/multimodal/multimodal_agent.ts index 4ab66c1..045ef86 100644 --- a/agents/src/multimodal/multimodal_agent.ts +++ b/agents/src/multimodal/multimodal_agent.ts @@ -5,6 +5,8 @@ import type { LocalTrackPublication, RemoteAudioTrack, RemoteParticipant, + RemoteTrack, + RemoteTrackPublication, Room, } from '@livekit/rtc-node'; import { @@ -135,12 +137,17 @@ export class MultimodalAgent extends EventEmitter { this.#updateState(); room.on(RoomEvent.ParticipantConnected, (participant: RemoteParticipant) => { - if (!this.linkedParticipant) { + // automatically link to the first participant that connects, if not already linked + if (this.linkedParticipant) { return; } - this.#linkParticipant(participant.identity); }); + room.on(RoomEvent.TrackPublished, () => { + // in case we are connected before the participant has published, we'd need to re-subscribe + this.#subscribeToMicrophone(); + }); + room.on(RoomEvent.TrackSubscribed, this.#handleTrackSubscription.bind(this)); this.room = room; this.#participant = participant; @@ -297,14 +304,50 @@ export class MultimodalAgent extends EventEmitter { if (this.linkedParticipant.trackPublications.size > 0) { this.#subscribeToMicrophone(); - } else { - this.room.on(RoomEvent.TrackPublished, () => { - this.#subscribeToMicrophone(); - }); + } + + // also check if already subscribed + for (const publication of this.linkedParticipant.trackPublications.values()) { + if (publication.source === TrackSource.SOURCE_MICROPHONE && publication.track) { + this.#handleTrackSubscription(publication.track, publication, this.linkedParticipant); + break; + } } } #subscribeToMicrophone(): void { + if (!this.linkedParticipant) { + this.#logger.error('Participant is not set'); + return; + } + + let microphonePublication: RemoteTrackPublication | undefined = undefined; + for (const publication of this.linkedParticipant.trackPublications.values()) { + if (publication.source === TrackSource.SOURCE_MICROPHONE) { + microphonePublication = publication; + break; + } + } + if (!microphonePublication) { + return; + } + + if (!microphonePublication.subscribed) { + microphonePublication.setSubscribed(true); + } + } + + #handleTrackSubscription( + track: RemoteTrack, + publication: RemoteTrackPublication, + participant: RemoteParticipant, + ) { + if ( + publication.source !== TrackSource.SOURCE_MICROPHONE || + participant.identity !== this.linkedParticipant?.identity + ) { + return; + } const readAudioStreamTask = async (audioStream: AudioStream) => { const bstream = new AudioByteStream( this.model.sampleRate, @@ -319,46 +362,24 @@ export class MultimodalAgent extends EventEmitter { } } }; + this.subscribedTrack = track; - if (!this.linkedParticipant) { - this.#logger.error('Participant is not set'); - return; + if (this.readMicroTask) { + this.readMicroTask.cancel(); } - for (const publication of this.linkedParticipant.trackPublications.values()) { - if (publication.source !== TrackSource.SOURCE_MICROPHONE) { - continue; - } - - if (!publication.subscribed) { - publication.setSubscribed(true); - } - - const track = publication.track; - - if (track && track !== this.subscribedTrack) { - this.subscribedTrack = track; - - if (this.readMicroTask) { - this.readMicroTask.cancel(); - } - - let cancel: () => void; - this.readMicroTask = { - promise: new Promise((resolve, reject) => { - cancel = () => { - reject(new Error('Task cancelled')); - }; - readAudioStreamTask( - new AudioStream(track, this.model.sampleRate, this.model.numChannels), - ) - .then(resolve) - .catch(reject); - }), - cancel: () => cancel(), + let cancel: () => void; + this.readMicroTask = { + promise: new Promise((resolve, reject) => { + cancel = () => { + reject(new Error('Task cancelled')); }; - } - } + readAudioStreamTask(new AudioStream(track, this.model.sampleRate, this.model.numChannels)) + .then(resolve) + .catch(reject); + }), + cancel: () => cancel(), + }; } #getLocalTrackSid(): string | null {