From 858e5271cadb7048ff28e668f481b9a8fe0d62c8 Mon Sep 17 00:00:00 2001 From: Amit Abitbul Date: Sat, 21 Dec 2024 23:50:38 +0200 Subject: [PATCH] Add support in WSS --- whisper_live/client.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/whisper_live/client.py b/whisper_live/client.py index 4045ce2..6ee2d15 100644 --- a/whisper_live/client.py +++ b/whisper_live/client.py @@ -70,7 +70,10 @@ def __init__( self.audio_bytes = None if host is not None and port is not None: - socket_url = f"ws://{host}:{port}" + if port == 443: + socket_url = f"wss://{host}" + else: + socket_url = f"ws://{host}:{port}" self.client_socket = websocket.WebSocketApp( socket_url, on_open=lambda ws: self.on_open(ws), @@ -79,6 +82,7 @@ def __init__( on_close=lambda ws, close_status_code, close_msg: self.on_close( ws, close_status_code, close_msg ), + header=['User-Agent: WhisperLiveClient'] ) else: print("[ERROR]: No host or port specified.") @@ -112,9 +116,9 @@ def process_segments(self, segments): for i, seg in enumerate(segments): if not text or text[-1] != seg["text"]: text.append(seg["text"]) - if i == len(segments) - 1 and not seg.get("completed", False): + if i == len(segments) - 1: self.last_segment = seg - elif (self.server_backend == "faster_whisper" and seg.get("completed", False) and + elif (self.server_backend == "faster_whisper" and (not self.transcript or float(seg['start']) >= float(self.transcript[-1]['end']))): self.transcript.append(seg) @@ -203,9 +207,7 @@ def on_open(self, ws): "language": self.language, "task": self.task, "model": self.model, - "use_vad": self.use_vad, - "max_clients": self.max_clients, - "max_connection_time": self.max_connection_time, + "use_vad": self.use_vad } ) ) @@ -259,9 +261,7 @@ def write_srt_file(self, output_path="output.srt"): """ if self.server_backend == "faster_whisper": - if not self.transcript and self.last_segment is not None: - self.transcript.append(self.last_segment) - elif self.last_segment and self.transcript[-1]["text"] != self.last_segment["text"]: + if (self.last_segment): self.transcript.append(self.last_segment) utils.create_srt_file(self.transcript, output_path) @@ -689,15 +689,8 @@ def __init__( output_recording_filename="./output_recording.wav", output_transcription_path="./output.srt", log_transcription=True, - max_clients=4, - max_connection_time=600, ): - self.client = Client( - host, port, lang, translate, model, srt_file_path=output_transcription_path, - use_vad=use_vad, log_transcription=log_transcription, max_clients=max_clients, - max_connection_time=max_connection_time - ) - + self.client = Client(host, port, lang, translate, model, srt_file_path=output_transcription_path, use_vad=use_vad, log_transcription=log_transcription) if save_output_recording and not output_recording_filename.endswith(".wav"): raise ValueError(f"Please provide a valid `output_recording_filename`: {output_recording_filename}") if not output_transcription_path.endswith(".srt"):