From 2770d842664c3162a99a3b5d4a4c15232bfc76ad Mon Sep 17 00:00:00 2001
From: ZachNagengast
Date: Sun, 6 Oct 2024 17:00:19 -0700
Subject: [PATCH 01/10] Release memory when transcribing single files

Co-authored-by: keleftheriou
---
 Sources/WhisperKit/Core/Utils.swift      |  8 ++++++++
 Sources/WhisperKit/Core/WhisperKit.swift | 21 +++++++++++++--------
 2 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift
index 8713510..1551aaf 100644
--- a/Sources/WhisperKit/Core/Utils.swift
+++ b/Sources/WhisperKit/Core/Utils.swift
@@ -188,6 +188,14 @@ public extension String {
 }
 
 extension AVAudioPCMBuffer {
+    /// Converts the buffer to a float array
+    func asFloatArray() throws -> [Float] {
+        guard let data = floatChannelData?.pointee else {
+            throw WhisperError.audioProcessingFailed("Error converting audio, missing floatChannelData")
+        }
+        return Array(UnsafeBufferPointer(start: data, count: Int(frameLength)))
+    }
+
     /// Appends the contents of another buffer to the current buffer
     func appendContents(of buffer: AVAudioPCMBuffer) -> Bool {
         return appendContents(of: buffer, startingFrame: 0, frameCount: buffer.frameLength)
diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift
index c1b66d5..3de1fe2 100644
--- a/Sources/WhisperKit/Core/WhisperKit.swift
+++ b/Sources/WhisperKit/Core/WhisperKit.swift
@@ -690,15 +690,20 @@ open class WhisperKit {
         callback: TranscriptionCallback = nil
     ) async throws -> [TranscriptionResult] {
         // Process input audio file into audio samples
-        let loadAudioStart = Date()
-        let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
-        let loadTime = Date().timeIntervalSince(loadAudioStart)
+        let audioArray = try await withThrowingTaskGroup(of: [Float].self) { group -> [Float] in
+            let loadAudioStart = Date()
+            let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
+            let loadTime = Date().timeIntervalSince(loadAudioStart)
+
+            let convertAudioStart = Date()
+            defer {
+                let convertTime = Date().timeIntervalSince(convertAudioStart)
+                currentTimings.audioLoading = loadTime + convertTime
+                Logging.debug("Audio loading time: \(loadTime), Audio convert time: \(convertTime)")
+            }
 
-        let convertAudioStart = Date()
-        let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
-        let convertTime = Date().timeIntervalSince(convertAudioStart)
-        currentTimings.audioLoading = loadTime + convertTime
-        Logging.debug("Audio loading time: \(loadTime), Audio convert time: \(convertTime)")
+            return AudioProcessor.convertBufferToArray(buffer: audioBuffer)
+        }
 
         let transcribeResults: [TranscriptionResult] = try await transcribe(
             audioArray: audioArray,

From e3078a807e0c667c9df3f063fae7ad077c925302 Mon Sep 17 00:00:00 2001
From: ZachNagengast
Date: Sun, 6 Oct 2024 18:08:15 -0700
Subject: [PATCH 02/10] Add method to load from file into float array iteratively

- Reduces peak memory by doing the array conversion while loading in chunks so the array copy size is lower
- Previously copied the entire buffer which spiked the memory 2x
---
 .../WhisperAX/Views/ContentView.swift    |  7 ++-
 .../Core/Audio/AudioProcessor.swift      | 55 ++++++++++++++++++-
 Sources/WhisperKit/Core/WhisperKit.swift | 14 ++---
 3 files changed, 64 insertions(+), 12 deletions(-)
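Note: a minimal usage sketch of the chunked loader this patch adds. The file
path is hypothetical; the point is that decoding and conversion happen in
10-minute chunks, so peak memory stays near one chunk plus the accumulated
result instead of spiking to roughly 2x the full file:

    import WhisperKit

    func loadSamples(path: String) throws -> [Float] {
        // Decodes and converts chunk-by-chunk inside an autoreleasepool,
        // releasing each intermediate AVAudioPCMBuffer before the next read.
        return try AudioProcessor.loadAudioAsFloatArray(fromPath: path)
    }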
diff --git a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
index 2a182fb..88a8c4f 100644
--- a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
+++ b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
@@ -1267,12 +1267,15 @@ struct ContentView: View {
 
     func transcribeCurrentFile(path: String) async throws {
         // Load and convert buffer in a limited scope
+        Logging.debug("Loading audio file: \(path)")
+        let loadingStart = Date()
         let audioFileSamples = try await Task {
             try autoreleasepool {
-                let audioFileBuffer = try AudioProcessor.loadAudio(fromPath: path)
-                return AudioProcessor.convertBufferToArray(buffer: audioFileBuffer)
+                return try AudioProcessor.loadAudioAsFloatArray(fromPath: path)
             }
         }.value
+        Logging.debug("Loaded audio file in \(Date().timeIntervalSince(loadingStart)) seconds")
+
         let transcription = try await transcribeAudioSamples(audioFileSamples)
 
diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
index c3958cb..74922b1 100644
--- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
+++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
@@ -197,7 +197,15 @@ public class AudioProcessor: NSObject, AudioProcessing {
         let audioFileURL = URL(fileURLWithPath: audioFilePath)
         let audioFile = try AVAudioFile(forReading: audioFileURL, commonFormat: .pcmFormatFloat32, interleaved: false)
 
+        return try loadAudio(fromFile: audioFile, startTime: startTime, endTime: endTime, maxReadFrameSize: maxReadFrameSize)
+    }
+
+    public static func loadAudio(
+        fromFile audioFile: AVAudioFile,
+        startTime: Double? = 0,
+        endTime: Double? = nil,
+        maxReadFrameSize: AVAudioFrameCount? = nil
+    ) throws -> AVAudioPCMBuffer {
         let sampleRate = audioFile.fileFormat.sampleRate
         let channelCount = audioFile.fileFormat.channelCount
         let frameLength = AVAudioFrameCount(audioFile.length)
@@ -243,13 +251,56 @@ public class AudioProcessor: NSObject, AudioProcessing {
         }
     }
 
+    public static func loadAudioAsFloatArray(
+        fromPath audioFilePath: String,
+        startTime: Double? = 0,
+        endTime: Double? = nil
+    ) throws -> [Float] {
+        guard FileManager.default.fileExists(atPath: audioFilePath) else {
+            throw WhisperError.loadAudioFailed("Resource path does not exist \(audioFilePath)")
+        }
+
+        let audioFileURL = URL(fileURLWithPath: audioFilePath)
+        let audioFile = try AVAudioFile(forReading: audioFileURL, commonFormat: .pcmFormatFloat32, interleaved: false)
+        let inputFormat = audioFile.fileFormat
+        let inputFrameCount = AVAudioFrameCount(audioFile.length)
+        let inputDuration = Double(inputFrameCount) / inputFormat.sampleRate
+
+        let start = startTime ?? 0
+        let end = min(endTime ?? inputDuration, inputDuration)
+
+        // Load 10m of audio at a time to reduce peak memory while converting
+        // Particularly impactful for large audio files
+        let chunkDuration: Double = 60 * 10
+        var currentTime = start
+        var result: [Float] = []
+
+        while currentTime < end {
+            let chunkEnd = min(currentTime + chunkDuration, end)
+
+            try autoreleasepool {
+                let buffer = try loadAudio(
+                    fromFile: audioFile,
+                    startTime: currentTime,
+                    endTime: chunkEnd
+                )
+
+                let floatArray = Self.convertBufferToArray(buffer: buffer)
+                result.append(contentsOf: floatArray)
+            }
+
+            currentTime = chunkEnd
+        }
+
+        return result
+    }
+
     public static func loadAudio(at audioPaths: [String]) async -> [Result<[Float], Swift.Error>] {
         await withTaskGroup(of: [(index: Int, result: Result<[Float], Swift.Error>)].self) { taskGroup -> [Result<[Float], Swift.Error>] in
             for (index, audioPath) in audioPaths.enumerated() {
                 taskGroup.addTask {
                     do {
-                        let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
-                        let audio = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
+                        let audio = try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath)
                         return [(index: index, result: .success(audio))]
                     } catch {
                         return [(index: index, result: .failure(error))]
diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift
index 3de1fe2..744a371 100644
--- a/Sources/WhisperKit/Core/WhisperKit.swift
+++ b/Sources/WhisperKit/Core/WhisperKit.swift
@@ -415,7 +415,8 @@ open class WhisperKit {
     open func detectLanguage(
         audioPath: String
     ) async throws -> (language: String, langProbs: [String: Float]) {
-        let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
+        // Only need the first 30s for language detection
+        let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath, endTime: 30.0)
         let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
         return try await detectLangauge(audioArray: audioArray)
     }
@@ -691,18 +692,15 @@ open class WhisperKit {
     ) async throws -> [TranscriptionResult] {
         // Process input audio file into audio samples
         let audioArray = try await withThrowingTaskGroup(of: [Float].self) { group -> [Float] in
-            let loadAudioStart = Date()
-            let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
-            let loadTime = Date().timeIntervalSince(loadAudioStart)
-
             let convertAudioStart = Date()
             defer {
                 let convertTime = Date().timeIntervalSince(convertAudioStart)
-                currentTimings.audioLoading = loadTime + convertTime
-                Logging.debug("Audio loading time: \(loadTime), Audio convert time: \(convertTime)")
+                currentTimings.audioLoading = convertTime
+                Logging.debug("Audio loading and convert time: \(convertTime)")
+                logCurrentMemoryUsage("Audio Loading and Convert")
             }
 
-            return AudioProcessor.convertBufferToArray(buffer: audioBuffer)
+            return try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath)
         }
 
         let transcribeResults: [TranscriptionResult] = try await transcribe(

From baea188625a9b4c8d533c00ffd6be4a66d5b0294 Mon Sep 17 00:00:00 2001
From: ZachNagengast
Date: Sun, 6 Oct 2024 19:22:44 -0700
Subject: [PATCH 03/10] Fix leak

---
 Sources/WhisperKit/Core/Audio/AudioProcessor.swift | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
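Note: the device-name property fetch hands back a CFString with a +1 retain
that the caller owns, so bridging it with takeUnretainedValue() left that
retain unbalanced and leaked one string per query. A standalone sketch of
the ownership rule (illustrative values, not the WhisperKit code):

    import CoreFoundation

    // Model an owned (+1) Core Foundation reference, as returned by a
    // Create/Copy-style CF call or the property fetch below.
    let owned: Unmanaged<CFString> = Unmanaged.passRetained("device" as CFString)

    // takeRetainedValue() consumes the +1, so ARC releases it normally.
    let name = owned.takeRetainedValue() as String

    // takeUnretainedValue() would bridge without consuming the +1,
    // leaving the object permanently retained -- the leak fixed here.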
diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
index 74922b1..c136325 100644
--- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
+++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
@@ -695,7 +695,7 @@ public class AudioProcessor: NSObject, AudioProcessing {
                 &propertySize,
                 &name
             )
-            if status == noErr, let deviceNameCF = name?.takeUnretainedValue() as String? {
+            if status == noErr, let deviceNameCF = name?.takeRetainedValue() as String? {
                 deviceName = deviceNameCF
             }
 

From 33759ed805cefda10572cd7367960b35e77b80a8 Mon Sep 17 00:00:00 2001
From: ZachNagengast
Date: Mon, 7 Oct 2024 16:50:50 -0700
Subject: [PATCH 04/10] Use vad by default in examples

---
 .../WhisperAX/Views/ContentView.swift    |  4 ++-
 Sources/WhisperKit/Core/WhisperKit.swift | 28 +++++++++----------
 Sources/WhisperKitCLI/CLIArguments.swift |  8 +++---
 3 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
index 88a8c4f..6aea8f5 100644
--- a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
+++ b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
@@ -51,7 +51,8 @@ struct ContentView: View {
     @AppStorage("silenceThreshold") private var silenceThreshold: Double = 0.3
     @AppStorage("useVAD") private var useVAD: Bool = true
     @AppStorage("tokenConfirmationsNeeded") private var tokenConfirmationsNeeded: Double = 2
-    @AppStorage("chunkingStrategy") private var chunkingStrategy: ChunkingStrategy = .none
+    @AppStorage("concurrentWorkerCount") private var concurrentWorkerCount: Int = 4
+    @AppStorage("chunkingStrategy") private var chunkingStrategy: ChunkingStrategy = .vad
     @AppStorage("encoderComputeUnits") private var encoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine
     @AppStorage("decoderComputeUnits") private var decoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine
@@ -1317,6 +1318,7 @@ struct ContentView: View {
             withoutTimestamps: !enableTimestamps,
             wordTimestamps: true,
             clipTimestamps: seekClip,
+            concurrentWorkerCount: concurrentWorkerCount,
             chunkingStrategy: chunkingStrategy
         )
diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift
index 744a371..26e10b2 100644
--- a/Sources/WhisperKit/Core/WhisperKit.swift
+++ b/Sources/WhisperKit/Core/WhisperKit.swift
@@ -809,23 +809,23 @@ open class WhisperKit {
             throw WhisperError.tokenizerUnavailable()
         }
 
-        let childProgress = Progress()
-        progress.totalUnitCount += 1
-        progress.addChild(childProgress, withPendingUnitCount: 1)
-
-        let transcribeTask = TranscribeTask(
-            currentTimings: currentTimings,
-            progress: childProgress,
-            audioEncoder: audioEncoder,
-            featureExtractor: featureExtractor,
-            segmentSeeker: segmentSeeker,
-            textDecoder: textDecoder,
-            tokenizer: tokenizer
-        )
-
         do {
             try Task.checkCancellation()
 
+            let childProgress = Progress()
+            progress.totalUnitCount += 1
+            progress.addChild(childProgress, withPendingUnitCount: 1)
+
+            let transcribeTask = TranscribeTask(
+                currentTimings: currentTimings,
+                progress: childProgress,
+                audioEncoder: audioEncoder,
+                featureExtractor: featureExtractor,
+                segmentSeeker: segmentSeeker,
+                textDecoder: textDecoder,
+                tokenizer: tokenizer
+            )
+
             let transcribeTaskResult = try await transcribeTask.run(
                 audioArray: audioArray,
                 decodeOptions: decodeOptions,
diff --git a/Sources/WhisperKitCLI/CLIArguments.swift b/Sources/WhisperKitCLI/CLIArguments.swift
index b76439b..6d10404 100644
--- a/Sources/WhisperKitCLI/CLIArguments.swift
+++ b/Sources/WhisperKitCLI/CLIArguments.swift
@@ -103,9 +103,9 @@ struct CLIArguments: ParsableArguments {
     @Flag(help: "Simulate streaming transcription using the input audio file")
     var streamSimulated: Bool = false
 
-    @Option(help: "Maximum concurrent inference, might be helpful when processing more than 1 audio file at the same time. 0 means unlimited")
-    var concurrentWorkerCount: Int = 0
+    @Option(help: "Maximum concurrent inference, might be helpful when processing more than 1 audio file at the same time. 0 means unlimited. Default: 4")
+    var concurrentWorkerCount: Int = 4
 
-    @Option(help: "Chunking strategy for audio processing, `nil` means no chunking, `vad` means using voice activity detection")
-    var chunkingStrategy: String? = nil
+    @Option(help: "Chunking strategy for audio processing, `none` means no chunking, `vad` means using voice activity detection. Default: `vad`")
+    var chunkingStrategy: String? = "vad"
 }

From 37a4d4f8816719b558d16c8862fb30fba7c990ad Mon Sep 17 00:00:00 2001
From: ZachNagengast
Date: Mon, 7 Oct 2024 19:00:42 -0700
Subject: [PATCH 05/10] Fix vad thread issue

---
 Sources/WhisperKit/Core/TextDecoder.swift | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)
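Note: the thread issue appears to be unsynchronized access to the plain
`shouldEarlyStop` dictionary, written from a background queue while the
decoding loop reads it; the fix hops to the main actor before writing. A
minimal sketch of that pattern (names simplified, not the actual WhisperKit
types):

    import Foundation

    final class EarlyStopFlags {
        // Isolating the mutable dictionary to the main actor serializes
        // reads and writes instead of letting them race across threads.
        @MainActor private var flags: [UUID: Bool] = [:]

        func reset(_ id: UUID) {
            Task { await MainActor.run { self.flags[id] = false } }
        }
    }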
diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift
index 5026f15..08550c9 100644
--- a/Sources/WhisperKit/Core/TextDecoder.swift
+++ b/Sources/WhisperKit/Core/TextDecoder.swift
@@ -591,9 +591,11 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
         var hasAlignment = false
         var isFirstTokenLogProbTooLow = false
         let windowUUID = UUID()
-        DispatchQueue.global().async { [weak self] in
+        Task { [weak self] in
             guard let self = self else { return }
-            self.shouldEarlyStop[windowUUID] = false
+            await MainActor.run {
+                self.shouldEarlyStop[windowUUID] = false
+            }
         }
         for tokenIndex in prefilledIndex..

Date: Mon, 7 Oct 2024 19:17:33 -0700
Subject: [PATCH 06/10] Fix unused warning

---
 Sources/WhisperKit/Core/Audio/AudioChunker.swift   | 1 -
 Sources/WhisperKit/Core/Audio/AudioProcessor.swift | 2 --
 2 files changed, 3 deletions(-)

diff --git a/Sources/WhisperKit/Core/Audio/AudioChunker.swift b/Sources/WhisperKit/Core/Audio/AudioChunker.swift
index 325d41a..fadec10 100644
--- a/Sources/WhisperKit/Core/Audio/AudioChunker.swift
+++ b/Sources/WhisperKit/Core/Audio/AudioChunker.swift
@@ -81,7 +81,6 @@ open class VADAudioChunker: AudioChunking {
         // Typically this will be the full audio file, unless seek points are explicitly provided
         var startIndex = seekClipStart
         while startIndex < seekClipEnd - windowPadding {
-            let currentFrameLength = audioArray.count
             guard startIndex >= 0 && startIndex < audioArray.count else {
                 throw WhisperError.audioProcessingFailed("startIndex is outside the buffer size")
             }
diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
index 096d01c..1b01607 100644
--- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
+++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
@@ -93,8 +93,6 @@ public extension AudioProcessing {
     }
 
     static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? {
-        let currentFrameLength = audioArray.count
-
         guard startIndex >= 0 && startIndex < audioArray.count else {
             Logging.error("startIndex is outside the buffer size")
             return nil

From 6da35b5c498c4fe26e2419d6d0b2fd52766b88d2 Mon Sep 17 00:00:00 2001
From: ZachNagengast
Date: Mon, 7 Oct 2024 22:52:47 -0700
Subject: [PATCH 07/10] Revert change to early stop callback

---
 Sources/WhisperKit/Core/TextDecoder.swift | 10 ++++------
 Tests/WhisperKitTests/UnitTests.swift     |  5 ++++-
 2 files changed, 8 insertions(+), 7 deletions(-)
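Note: with the callback back on a background queue, this is how a caller
drives early stopping from outside. A sketch only -- the `pipe` instance and
audio path are placeholders, and the token budget is illustrative:

    import WhisperKit

    func transcribeWithBudget(_ pipe: WhisperKit) async throws -> [TranscriptionResult] {
        let earlyStopTokenCount = 10
        let callback: TranscriptionCallback = { progress in
            // Returning true (or nil) continues decoding; returning false
            // requests an early stop via the shouldEarlyStop flag.
            progress.tokens.count <= earlyStopTokenCount
        }
        return try await pipe.transcribe(
            audioPath: "/path/to/audio.wav",
            decodeOptions: DecodingOptions(),
            callback: callback
        )
    }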
diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift
index 08550c9..a3bc208 100644
--- a/Sources/WhisperKit/Core/TextDecoder.swift
+++ b/Sources/WhisperKit/Core/TextDecoder.swift
@@ -737,15 +737,13 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
 
             // Call the callback if it is provided on a background thread to avoid blocking the decoding loop
             if let callback = callback {
-                Task { [weak self] in
+                DispatchQueue.global().async { [weak self] in
                     guard let self = self else { return }
                     let shouldContinue = callback(result)
                     if let shouldContinue = shouldContinue, !shouldContinue, !isPrefill {
-                        await MainActor.run {
-                            Logging.debug("Early stopping")
-                            if self.shouldEarlyStop.keys.contains(windowUUID) {
-                                self.shouldEarlyStop[windowUUID] = true
-                            }
+                        Logging.debug("Early stopping")
+                        if self.shouldEarlyStop.keys.contains(windowUUID) {
+                            self.shouldEarlyStop[windowUUID] = true
                         }
                     }
                 }
diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift
index 5a0ca5c..55ee757 100644
--- a/Tests/WhisperKitTests/UnitTests.swift
+++ b/Tests/WhisperKitTests/UnitTests.swift
@@ -548,9 +548,11 @@ final class UnitTests: XCTestCase {
     }
 
     func testDecodingEarlyStopping() async throws {
+        let earlyStopTokenCount = 10
         let options = DecodingOptions()
         let continuationCallback: TranscriptionCallback = { (progress: TranscriptionProgress) -> Bool? in
-            false
+            // Stop after only 10 tokens (full test audio contains 16)
+            return progress.tokens.count <= earlyStopTokenCount ? true : false
         }
 
         let result = try await XCTUnwrapAsync(
@@ -576,6 +578,7 @@ final class UnitTests: XCTestCase {
         XCTAssertNotNil(resultWithWait)
         let tokenCountWithWait = resultWithWait.segments.flatMap { $0.tokens }.count
         let decodingTimePerTokenWithWait = resultWithWait.timings.decodingLoop / Double(tokenCountWithWait)
+        Logging.debug("Decoding loop without wait: \(result.timings.decodingLoop), with wait: \(resultWithWait.timings.decodingLoop)")
 
         // Assert that the decoding predictions per token are not slower with the waiting
         XCTAssertEqual(decodingTimePerTokenWithWait, decodingTimePerToken, accuracy: decodingTimePerToken, "Decoding predictions per token should not be significantly slower with waiting")

From 23b8226ccf61d202044a26a34df620f2480dd182 Mon Sep 17 00:00:00 2001
From: ZachNagengast
Date: Mon, 7 Oct 2024 23:02:26 -0700
Subject: [PATCH 08/10] Fix warnings

- Optional cli commands are deprecated
- @_disfavoredOverload required @available to prevent infinite loop
---
 Sources/WhisperKit/Core/Utils.swift       |  3 ++-
 Sources/WhisperKitCLI/CLIArguments.swift  |  2 +-
 Sources/WhisperKitCLI/TranscribeCLI.swift | 14 +++-----------
 3 files changed, 6 insertions(+), 13 deletions(-)

diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift
index cdb7735..c9e9a54 100644
--- a/Sources/WhisperKit/Core/Utils.swift
+++ b/Sources/WhisperKit/Core/Utils.swift
@@ -454,8 +454,9 @@ public func modelSupport(for deviceName: String, from config: ModelSupportConfig
 /// Deprecated
 @available(*, deprecated, message: "Subject to removal in a future version. Use modelSupport(for:from:) -> ModelSupport instead.")
 @_disfavoredOverload
+@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
 public func modelSupport(for deviceName: String, from config: ModelSupportConfig? = nil) -> (default: String, disabled: [String]) {
-    let modelSupport = modelSupport(for: deviceName, from: config)
+    let modelSupport: ModelSupport = modelSupport(for: deviceName, from: config)
     return (modelSupport.default, modelSupport.disabled)
 }
diff --git a/Sources/WhisperKitCLI/CLIArguments.swift b/Sources/WhisperKitCLI/CLIArguments.swift
index 6d10404..0df2b73 100644
--- a/Sources/WhisperKitCLI/CLIArguments.swift
+++ b/Sources/WhisperKitCLI/CLIArguments.swift
@@ -107,5 +107,5 @@ struct CLIArguments: ParsableArguments {
     var concurrentWorkerCount: Int = 4
 
     @Option(help: "Chunking strategy for audio processing, `none` means no chunking, `vad` means using voice activity detection. Default: `vad`")
-    var chunkingStrategy: String? = "vad"
+    var chunkingStrategy: String = "vad"
 }
diff --git a/Sources/WhisperKitCLI/TranscribeCLI.swift b/Sources/WhisperKitCLI/TranscribeCLI.swift
index 9a5f31c..62e8ec8 100644
--- a/Sources/WhisperKitCLI/TranscribeCLI.swift
+++ b/Sources/WhisperKitCLI/TranscribeCLI.swift
@@ -38,10 +38,8 @@ struct TranscribeCLI: AsyncParsableCommand {
             cliArguments.audioPath = audioFiles.map { audioFolder + "/" + $0 }
         }
 
-        if let chunkingStrategyRaw = cliArguments.chunkingStrategy {
-            if ChunkingStrategy(rawValue: chunkingStrategyRaw) == nil {
-                throw ValidationError("Wrong chunking strategy \"\(chunkingStrategyRaw)\", valid strategies: \(ChunkingStrategy.allCases.map { $0.rawValue })")
-            }
+        if ChunkingStrategy(rawValue: cliArguments.chunkingStrategy) == nil {
+            throw ValidationError("Wrong chunking strategy \"\(cliArguments.chunkingStrategy)\", valid strategies: \(ChunkingStrategy.allCases.map { $0.rawValue })")
         }
     }
@@ -318,12 +316,6 @@ struct TranscribeCLI: AsyncParsableCommand {
     }
 
     private func decodingOptions(task: DecodingTask) -> DecodingOptions {
-        let chunkingStrategy: ChunkingStrategy? =
-            if let chunkingStrategyRaw = cliArguments.chunkingStrategy {
-                ChunkingStrategy(rawValue: chunkingStrategyRaw)
-            } else {
-                nil
-            }
         return DecodingOptions(
             verbose: cliArguments.verbose,
             task: task,
@@ -344,7 +336,7 @@ struct TranscribeCLI: AsyncParsableCommand {
             firstTokenLogProbThreshold: cliArguments.firstTokenLogProbThreshold,
             noSpeechThreshold: cliArguments.noSpeechThreshold ?? 0.6,
             concurrentWorkerCount: cliArguments.concurrentWorkerCount,
-            chunkingStrategy: chunkingStrategy
+            chunkingStrategy: ChunkingStrategy(rawValue: cliArguments.chunkingStrategy)
         )
     }

From 14461c010b664b6e4cd52bd5ea2d4136b3d6f42b Mon Sep 17 00:00:00 2001
From: Zach Nagengast
Date: Tue, 8 Oct 2024 08:39:10 -0700
Subject: [PATCH 09/10] PR review - simplify early stop test logic

Co-authored-by: Andrey Leonov
---
 Tests/WhisperKitTests/UnitTests.swift | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift
index 55ee757..cabebe9 100644
--- a/Tests/WhisperKitTests/UnitTests.swift
+++ b/Tests/WhisperKitTests/UnitTests.swift
@@ -552,7 +552,7 @@ final class UnitTests: XCTestCase {
         let options = DecodingOptions()
         let continuationCallback: TranscriptionCallback = { (progress: TranscriptionProgress) -> Bool? in
             // Stop after only 10 tokens (full test audio contains 16)
-            return progress.tokens.count <= earlyStopTokenCount ? true : false
+            return progress.tokens.count <= earlyStopTokenCount
         }
 

From 0a46e6f521fd07f6f06f4ecadba2b2cfd6315dcc Mon Sep 17 00:00:00 2001
From: ZachNagengast
Date: Tue, 8 Oct 2024 08:47:47 -0700
Subject: [PATCH 10/10] Cleanup from review

---
 Sources/WhisperKit/Core/Audio/AudioProcessor.swift | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
index 1b01607..89d3132 100644
--- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
+++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
@@ -260,9 +260,9 @@ public class AudioProcessor: NSObject, AudioProcessing {
         let audioFileURL = URL(fileURLWithPath: audioFilePath)
         let audioFile = try AVAudioFile(forReading: audioFileURL, commonFormat: .pcmFormatFloat32, interleaved: false)
-        let inputFormat = audioFile.fileFormat
+        let inputSampleRate = audioFile.fileFormat.sampleRate
         let inputFrameCount = AVAudioFrameCount(audioFile.length)
-        let inputDuration = Double(inputFrameCount) / inputFormat.sampleRate
+        let inputDuration = Double(inputFrameCount) / inputSampleRate
 
         let start = startTime ?? 0
         let end = min(endTime ?? inputDuration, inputDuration)
@@ -329,10 +329,10 @@ public class AudioProcessor: NSObject, AudioProcessing {
         frameCount: AVAudioFrameCount? = nil,
         maxReadFrameSize: AVAudioFrameCount = Constants.defaultAudioReadFrameSize
     ) -> AVAudioPCMBuffer? {
-        let inputFormat = audioFile.fileFormat
+        let inputSampleRate = audioFile.fileFormat.sampleRate
         let inputStartFrame = audioFile.framePosition
         let inputFrameCount = frameCount ?? AVAudioFrameCount(audioFile.length)
-        let inputDuration = Double(inputFrameCount) / inputFormat.sampleRate
+        let inputDuration = Double(inputFrameCount) / inputSampleRate
         let endFramePosition = min(inputStartFrame + AVAudioFramePosition(inputFrameCount), audioFile.length + 1)
 
         guard let outputFormat = AVAudioFormat(standardFormatWithSampleRate: sampleRate, channels: channelCount) else {
         let remainingFrames = AVAudioFrameCount(endFramePosition - audioFile.framePosition)
         let framesToRead = min(remainingFrames, maxReadFrameSize)
 
-        let currentPositionInSeconds = Double(audioFile.framePosition) / inputFormat.sampleRate
-        let nextPositionInSeconds = (Double(audioFile.framePosition) + Double(framesToRead)) / inputFormat.sampleRate
+        let currentPositionInSeconds = Double(audioFile.framePosition) / inputSampleRate
+        let nextPositionInSeconds = (Double(audioFile.framePosition) + Double(framesToRead)) / inputSampleRate
         Logging.debug("Resampling \(String(format: "%.2f", currentPositionInSeconds))s - \(String(format: "%.2f", nextPositionInSeconds))s")
 
         do {