diff --git a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
index 19320c5..664703b 100644
--- a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
+++ b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
@@ -51,7 +51,8 @@ struct ContentView: View {
     @AppStorage("silenceThreshold") private var silenceThreshold: Double = 0.3
     @AppStorage("useVAD") private var useVAD: Bool = true
     @AppStorage("tokenConfirmationsNeeded") private var tokenConfirmationsNeeded: Double = 2
-    @AppStorage("chunkingStrategy") private var chunkingStrategy: ChunkingStrategy = .none
+    @AppStorage("concurrentWorkerCount") private var concurrentWorkerCount: Int = 4
+    @AppStorage("chunkingStrategy") private var chunkingStrategy: ChunkingStrategy = .vad
     @AppStorage("encoderComputeUnits") private var encoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine
     @AppStorage("decoderComputeUnits") private var decoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine
@@ -1269,12 +1270,15 @@ struct ContentView: View {
     func transcribeCurrentFile(path: String) async throws {
         // Load and convert buffer in a limited scope
+        Logging.debug("Loading audio file: \(path)")
+        let loadingStart = Date()
         let audioFileSamples = try await Task {
             try autoreleasepool {
-                let audioFileBuffer = try AudioProcessor.loadAudio(fromPath: path)
-                return AudioProcessor.convertBufferToArray(buffer: audioFileBuffer)
+                return try AudioProcessor.loadAudioAsFloatArray(fromPath: path)
             }
         }.value
+        Logging.debug("Loaded audio file in \(Date().timeIntervalSince(loadingStart)) seconds")

         let transcription = try await transcribeAudioSamples(audioFileSamples)
@@ -1316,6 +1320,7 @@ struct ContentView: View {
             withoutTimestamps: !enableTimestamps,
             wordTimestamps: true,
             clipTimestamps: seekClip,
+            concurrentWorkerCount: concurrentWorkerCount,
             chunkingStrategy: chunkingStrategy
         )
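The `@AppStorage`-backed `chunkingStrategy` above is persisted by its raw value, which requires `ChunkingStrategy` to be `RawRepresentable` with a `String` (or `Int`) raw value. A minimal sketch of the shape this implies, inferred from the raw values the CLI accepts ("none", "vad") and the `ChunkingStrategy.allCases` usage later in this diff, not copied from the WhisperKit source:

    // Assumed declaration, for illustration only.
    enum ChunkingStrategy: String, CaseIterable {
        case none
        case vad
    }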
diff --git a/Sources/WhisperKit/Core/Audio/AudioChunker.swift b/Sources/WhisperKit/Core/Audio/AudioChunker.swift
index 325d41a..fadec10 100644
--- a/Sources/WhisperKit/Core/Audio/AudioChunker.swift
+++ b/Sources/WhisperKit/Core/Audio/AudioChunker.swift
@@ -81,7 +81,6 @@ open class VADAudioChunker: AudioChunking {
         // Typically this will be the full audio file, unless seek points are explicitly provided
         var startIndex = seekClipStart
         while startIndex < seekClipEnd - windowPadding {
-            let currentFrameLength = audioArray.count
             guard startIndex >= 0 && startIndex < audioArray.count else {
                 throw WhisperError.audioProcessingFailed("startIndex is outside the buffer size")
             }
diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
index 89edeab..89d3132 100644
--- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
+++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
@@ -93,8 +93,6 @@
     }

     static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? {
-        let currentFrameLength = audioArray.count
-
         guard startIndex >= 0 && startIndex < audioArray.count else {
             Logging.error("startIndex is outside the buffer size")
             return nil
@@ -197,7 +195,15 @@
         let audioFileURL = URL(fileURLWithPath: audioFilePath)
         let audioFile = try AVAudioFile(forReading: audioFileURL, commonFormat: .pcmFormatFloat32, interleaved: false)

+        return try loadAudio(fromFile: audioFile, startTime: startTime, endTime: endTime, maxReadFrameSize: maxReadFrameSize)
+    }
+
+    public static func loadAudio(
+        fromFile audioFile: AVAudioFile,
+        startTime: Double? = 0,
+        endTime: Double? = nil,
+        maxReadFrameSize: AVAudioFrameCount? = nil
+    ) throws -> AVAudioPCMBuffer {
         let sampleRate = audioFile.fileFormat.sampleRate
         let channelCount = audioFile.fileFormat.channelCount
         let frameLength = AVAudioFrameCount(audioFile.length)
@@ -243,13 +249,56 @@
         }
     }

+    public static func loadAudioAsFloatArray(
+        fromPath audioFilePath: String,
+        startTime: Double? = 0,
+        endTime: Double? = nil
+    ) throws -> [Float] {
+        guard FileManager.default.fileExists(atPath: audioFilePath) else {
+            throw WhisperError.loadAudioFailed("Resource path does not exist \(audioFilePath)")
+        }
+
+        let audioFileURL = URL(fileURLWithPath: audioFilePath)
+        let audioFile = try AVAudioFile(forReading: audioFileURL, commonFormat: .pcmFormatFloat32, interleaved: false)
+        let inputSampleRate = audioFile.fileFormat.sampleRate
+        let inputFrameCount = AVAudioFrameCount(audioFile.length)
+        let inputDuration = Double(inputFrameCount) / inputSampleRate
+
+        let start = startTime ?? 0
+        let end = min(endTime ?? inputDuration, inputDuration)
+
+        // Load 10m of audio at a time to reduce peak memory while converting
+        // Particularly impactful for large audio files
+        let chunkDuration: Double = 60 * 10
+        var currentTime = start
+        var result: [Float] = []
+
+        while currentTime < end {
+            let chunkEnd = min(currentTime + chunkDuration, end)
+
+            try autoreleasepool {
+                let buffer = try loadAudio(
+                    fromFile: audioFile,
+                    startTime: currentTime,
+                    endTime: chunkEnd
+                )
+
+                let floatArray = Self.convertBufferToArray(buffer: buffer)
+                result.append(contentsOf: floatArray)
+            }
+
+            currentTime = chunkEnd
+        }
+
+        return result
+    }
+
     public static func loadAudio(at audioPaths: [String]) async -> [Result<[Float], Swift.Error>] {
         await withTaskGroup(of: [(index: Int, result: Result<[Float], Swift.Error>)].self) { taskGroup -> [Result<[Float], Swift.Error>] in
             for (index, audioPath) in audioPaths.enumerated() {
                 taskGroup.addTask {
                     do {
-                        let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
-                        let audio = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
+                        let audio = try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath)
                         return [(index: index, result: .success(audio))]
                     } catch {
                         return [(index: index, result: .failure(error))]
@@ -280,10 +329,10 @@
         frameCount: AVAudioFrameCount? = nil,
         maxReadFrameSize: AVAudioFrameCount = Constants.defaultAudioReadFrameSize
     ) -> AVAudioPCMBuffer? {
-        let inputFormat = audioFile.fileFormat
+        let inputSampleRate = audioFile.fileFormat.sampleRate
         let inputStartFrame = audioFile.framePosition
         let inputFrameCount = frameCount ?? AVAudioFrameCount(audioFile.length)
-        let inputDuration = Double(inputFrameCount) / inputFormat.sampleRate
+        let inputDuration = Double(inputFrameCount) / inputSampleRate
         let endFramePosition = min(inputStartFrame + AVAudioFramePosition(inputFrameCount), audioFile.length + 1)

         guard let outputFormat = AVAudioFormat(standardFormatWithSampleRate: sampleRate, channels: channelCount) else {
@@ -305,8 +354,8 @@
             let remainingFrames = AVAudioFrameCount(endFramePosition - audioFile.framePosition)
             let framesToRead = min(remainingFrames, maxReadFrameSize)

-            let currentPositionInSeconds = Double(audioFile.framePosition) / inputFormat.sampleRate
-            let nextPositionInSeconds = (Double(audioFile.framePosition) + Double(framesToRead)) / inputFormat.sampleRate
+            let currentPositionInSeconds = Double(audioFile.framePosition) / inputSampleRate
+            let nextPositionInSeconds = (Double(audioFile.framePosition) + Double(framesToRead)) / inputSampleRate
             Logging.debug("Resampling \(String(format: "%.2f", currentPositionInSeconds))s - \(String(format: "%.2f", nextPositionInSeconds))s")

             do {
@@ -644,7 +693,7 @@
                 &propertySize,
                 &name
             )
-            if status == noErr, let deviceNameCF = name?.takeUnretainedValue() as String? {
+            if status == noErr, let deviceNameCF = name?.takeRetainedValue() as String? {
                 deviceName = deviceNameCF
             }
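The new `loadAudioAsFloatArray(fromPath:)` above folds the previous load-then-convert call sites into a single call that decodes the file in 10-minute windows inside an `autoreleasepool`, keeping peak memory bounded for long recordings. A caller-side sketch of the before/after, assuming a throwing context; the file path is a placeholder:

    import WhisperKit

    // Previous pattern: materialize the whole PCM buffer, then convert it in one pass.
    let buffer = try AudioProcessor.loadAudio(fromPath: "/tmp/long-recording.wav")
    let samples = AudioProcessor.convertBufferToArray(buffer: buffer)

    // New pattern: decode and convert in 10-minute windows; same samples, lower peak memory.
    let samplesChunked = try AudioProcessor.loadAudioAsFloatArray(fromPath: "/tmp/long-recording.wav")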
diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift
index 5026f15..a3bc208 100644
--- a/Sources/WhisperKit/Core/TextDecoder.swift
+++ b/Sources/WhisperKit/Core/TextDecoder.swift
@@ -591,9 +591,11 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
         var hasAlignment = false
         var isFirstTokenLogProbTooLow = false
         let windowUUID = UUID()
-        DispatchQueue.global().async { [weak self] in
+        Task { [weak self] in
             guard let self = self else { return }
-            self.shouldEarlyStop[windowUUID] = false
+            await MainActor.run {
+                self.shouldEarlyStop[windowUUID] = false
+            }
         }

         for tokenIndex in prefilledIndex..<loopCount {
diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift
--- a/Sources/WhisperKit/Core/Utils.swift
+++ b/Sources/WhisperKit/Core/Utils.swift
+    func asFloatArray() throws -> [Float] {
+        guard let data = floatChannelData?.pointee else {
+            throw WhisperError.audioProcessingFailed("Error converting audio, missing floatChannelData")
+        }
+        return Array(UnsafeBufferPointer(start: data, count: Int(frameLength)))
+    }
+
     /// Appends the contents of another buffer to the current buffer
     func appendContents(of buffer: AVAudioPCMBuffer) -> Bool {
         return appendContents(of: buffer, startingFrame: 0, frameCount: buffer.frameLength)
@@ -446,8 +454,9 @@ public func modelSupport(for deviceName: String, from config: ModelSupportConfig
 /// Deprecated
 @available(*, deprecated, message: "Subject to removal in a future version. Use modelSupport(for:from:) -> ModelSupport instead.")
 @_disfavoredOverload
+@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
 public func modelSupport(for deviceName: String, from config: ModelSupportConfig? = nil) -> (default: String, disabled: [String]) {
-    let modelSupport = modelSupport(for: deviceName, from: config)
+    let modelSupport: ModelSupport = modelSupport(for: deviceName, from: config)
     return (modelSupport.default, modelSupport.disabled)
 }
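The `asFloatArray()` helper added above gives call sites that already hold an `AVAudioPCMBuffer` a throwing conversion to `[Float]`: it copies `frameLength` samples from the first channel and throws `WhisperError.audioProcessingFailed` when `floatChannelData` is unavailable. A minimal sketch of a call site inside the package; the buffer is assumed to be non-interleaved Float32 PCM, for example the result of `AudioProcessor.loadAudio(fromPath:)` with a placeholder path:

    // Decode a file into a Float32 PCM buffer, then flatten it to samples.
    let buffer = try AudioProcessor.loadAudio(fromPath: "/tmp/clip.wav")
    // Copies frameLength samples from the first channel, or throws if the
    // buffer exposes no floatChannelData.
    let samples: [Float] = try buffer.asFloatArray()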
diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift
index db8d07e..e2db1f8 100644
--- a/Sources/WhisperKit/Core/WhisperKit.swift
+++ b/Sources/WhisperKit/Core/WhisperKit.swift
@@ -446,7 +446,8 @@ open class WhisperKit {
     open func detectLanguage(
         audioPath: String
     ) async throws -> (language: String, langProbs: [String: Float]) {
-        let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
+        // Only need the first 30s for language detection
+        let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath, endTime: 30.0)
         let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
         return try await detectLangauge(audioArray: audioArray)
     }
@@ -721,15 +722,17 @@
         callback: TranscriptionCallback = nil
     ) async throws -> [TranscriptionResult] {
         // Process input audio file into audio samples
-        let loadAudioStart = Date()
-        let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
-        let loadTime = Date().timeIntervalSince(loadAudioStart)
+        let audioArray = try await withThrowingTaskGroup(of: [Float].self) { group -> [Float] in
+            let convertAudioStart = Date()
+            defer {
+                let convertTime = Date().timeIntervalSince(convertAudioStart)
+                currentTimings.audioLoading = convertTime
+                Logging.debug("Audio loading and convert time: \(convertTime)")
+                logCurrentMemoryUsage("Audio Loading and Convert")
+            }

-        let convertAudioStart = Date()
-        let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
-        let convertTime = Date().timeIntervalSince(convertAudioStart)
-        currentTimings.audioLoading = loadTime + convertTime
-        Logging.debug("Audio loading time: \(loadTime), Audio convert time: \(convertTime)")
+            return try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath)
+        }

         let transcribeResults: [TranscriptionResult] = try await transcribe(
             audioArray: audioArray,
@@ -837,23 +840,23 @@
             throw WhisperError.tokenizerUnavailable()
         }

-        let childProgress = Progress()
-        progress.totalUnitCount += 1
-        progress.addChild(childProgress, withPendingUnitCount: 1)
-
-        let transcribeTask = TranscribeTask(
-            currentTimings: currentTimings,
-            progress: childProgress,
-            audioEncoder: audioEncoder,
-            featureExtractor: featureExtractor,
-            segmentSeeker: segmentSeeker,
-            textDecoder: textDecoder,
-            tokenizer: tokenizer
-        )
-
         do {
             try Task.checkCancellation()
+            let childProgress = Progress()
+            progress.totalUnitCount += 1
+            progress.addChild(childProgress, withPendingUnitCount: 1)
+
+            let transcribeTask = TranscribeTask(
+                currentTimings: currentTimings,
+                progress: childProgress,
+                audioEncoder: audioEncoder,
+                featureExtractor: featureExtractor,
+                segmentSeeker: segmentSeeker,
+                textDecoder: textDecoder,
+                tokenizer: tokenizer
+            )
+
             let transcribeTaskResult = try await transcribeTask.run(
                 audioArray: audioArray,
                 decodeOptions: decodeOptions,
diff --git a/Sources/WhisperKitCLI/CLIArguments.swift b/Sources/WhisperKitCLI/CLIArguments.swift
index b76439b..0df2b73 100644
--- a/Sources/WhisperKitCLI/CLIArguments.swift
+++ b/Sources/WhisperKitCLI/CLIArguments.swift
@@ -103,9 +103,9 @@ struct CLIArguments: ParsableArguments {
     @Flag(help: "Simulate streaming transcription using the input audio file")
     var streamSimulated: Bool = false

-    @Option(help: "Maximum concurrent inference, might be helpful when processing more than 1 audio file at the same time. 0 means unlimited")
-    var concurrentWorkerCount: Int = 0
+    @Option(help: "Maximum concurrent inference, might be helpful when processing more than 1 audio file at the same time. 0 means unlimited. Default: 4")
+    var concurrentWorkerCount: Int = 4

-    @Option(help: "Chunking strategy for audio processing, `nil` means no chunking, `vad` means using voice activity detection")
-    var chunkingStrategy: String? = nil
+    @Option(help: "Chunking strategy for audio processing, `none` means no chunking, `vad` means using voice activity detection. Default: `vad`")
+    var chunkingStrategy: String = "vad"
 }
diff --git a/Sources/WhisperKitCLI/TranscribeCLI.swift b/Sources/WhisperKitCLI/TranscribeCLI.swift
index 9a5f31c..62e8ec8 100644
--- a/Sources/WhisperKitCLI/TranscribeCLI.swift
+++ b/Sources/WhisperKitCLI/TranscribeCLI.swift
@@ -38,10 +38,8 @@ struct TranscribeCLI: AsyncParsableCommand {
             cliArguments.audioPath = audioFiles.map { audioFolder + "/" + $0 }
         }

-        if let chunkingStrategyRaw = cliArguments.chunkingStrategy {
-            if ChunkingStrategy(rawValue: chunkingStrategyRaw) == nil {
-                throw ValidationError("Wrong chunking strategy \"\(chunkingStrategyRaw)\", valid strategies: \(ChunkingStrategy.allCases.map { $0.rawValue })")
-            }
+        if ChunkingStrategy(rawValue: cliArguments.chunkingStrategy) == nil {
+            throw ValidationError("Wrong chunking strategy \"\(cliArguments.chunkingStrategy)\", valid strategies: \(ChunkingStrategy.allCases.map { $0.rawValue })")
         }
     }
@@ -318,12 +316,6 @@ struct TranscribeCLI: AsyncParsableCommand {
     }

     private func decodingOptions(task: DecodingTask) -> DecodingOptions {
-        let chunkingStrategy: ChunkingStrategy? =
-            if let chunkingStrategyRaw = cliArguments.chunkingStrategy {
-                ChunkingStrategy(rawValue: chunkingStrategyRaw)
-            } else {
-                nil
-            }
         return DecodingOptions(
             verbose: cliArguments.verbose,
             task: task,
@@ -344,7 +336,7 @@
             firstTokenLogProbThreshold: cliArguments.firstTokenLogProbThreshold,
             noSpeechThreshold: cliArguments.noSpeechThreshold ?? 0.6,
             concurrentWorkerCount: cliArguments.concurrentWorkerCount,
-            chunkingStrategy: chunkingStrategy
+            chunkingStrategy: ChunkingStrategy(rawValue: cliArguments.chunkingStrategy)
         )
     }
diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift
index 5a0ca5c..cabebe9 100644
--- a/Tests/WhisperKitTests/UnitTests.swift
+++ b/Tests/WhisperKitTests/UnitTests.swift
@@ -548,9 +548,11 @@ final class UnitTests: XCTestCase {
     }

     func testDecodingEarlyStopping() async throws {
+        let earlyStopTokenCount = 10
         let options = DecodingOptions()
         let continuationCallback: TranscriptionCallback = { (progress: TranscriptionProgress) -> Bool? in
-            false
+            // Stop after only 10 tokens (full test audio contains 16)
+            return progress.tokens.count <= earlyStopTokenCount
         }

         let result = try await XCTUnwrapAsync(
@@ -576,6 +578,7 @@ final class UnitTests: XCTestCase {
         XCTAssertNotNil(resultWithWait)
         let tokenCountWithWait = resultWithWait.segments.flatMap { $0.tokens }.count
         let decodingTimePerTokenWithWait = resultWithWait.timings.decodingLoop / Double(tokenCountWithWait)
+        Logging.debug("Decoding loop without wait: \(result.timings.decodingLoop), with wait: \(resultWithWait.timings.decodingLoop)")

         // Assert that the decoding predictions per token are not slower with the waiting
         XCTAssertEqual(decodingTimePerTokenWithWait, decodingTimePerToken, accuracy: decodingTimePerToken, "Decoding predictions per token should not be significantly slower with waiting")
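Read together, the app, CLI, and core changes converge on the same decode configuration: four concurrent workers and VAD-based chunking by default, plus an optional continuation callback that can stop decoding early by returning `false`, as exercised by `testDecodingEarlyStopping` above. A minimal end-to-end sketch; `pipe` is assumed to be an already initialized `WhisperKit` instance and the audio path is a placeholder:

    import WhisperKit

    // Mirror the new defaults surfaced in ContentView and the CLI.
    let options = DecodingOptions(
        concurrentWorkerCount: 4,   // new default worker count
        chunkingStrategy: .vad      // split long audio at detected silence before decoding
    )

    // Optional: stop decoding once a caller-defined token budget is reached.
    let tokenBudget = 10
    let earlyStop: TranscriptionCallback = { (progress: TranscriptionProgress) -> Bool? in
        progress.tokens.count <= tokenBudget
    }

    let results = try await pipe.transcribe(
        audioPath: "path/to/audio.m4a",   // placeholder path
        decodeOptions: options,
        callback: earlyStop
    )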