From 576fc931b5bc79757aa545e8c4acb284719e7f95 Mon Sep 17 00:00:00 2001
From: thecollabagepatch
Date: Wed, 5 Jun 2024 01:37:56 -0400
Subject: [PATCH] hopefully no more oom

---
 cg_backup.py       | 291 ++++++++++++++++++++++++++++++++++-----------
 concurrent_gary.py |  72 +++++++----
 g4laudio.py        |  54 +++++++--
 3 files changed, 308 insertions(+), 109 deletions(-)

diff --git a/cg_backup.py b/cg_backup.py
index 5509e88..f482d47 100644
--- a/cg_backup.py
+++ b/cg_backup.py
@@ -1,6 +1,6 @@
 import os
 import base64
-from flask import Flask, request, jsonify
+from flask import Flask, request, jsonify, Response
 from flask_cors import CORS
 import yt_dlp as youtube_dl
 import torch
@@ -9,103 +9,250 @@
 from audiocraft.data.audio import audio_write
 import torchaudio.transforms as T
 from concurrent.futures import ThreadPoolExecutor
-from flask import current_app as app
 import json
+import librosa
+import soundfile as sf
+
+from rq import Queue, Retry
+from redis import Redis
+
+from pymongo import MongoClient, errors
+
+from bson import ObjectId, json_util
+import bson # Import bson to handle bson-related errors
+import re
+
+from g4laudio import continue_music
+
+
+# MongoDB connection with retry logic
+def get_mongo_client():
+    try:
+        client = MongoClient('mongodb://mongo:27017/', serverSelectionTimeoutMS=600000)
+        client.admin.command('ping') # Check if the connection is established
+        return client
+    except errors.ConnectionFailure as e:
+        print(f"Could not connect to MongoDB: {e}")
+        return None
+
+client = get_mongo_client()
+if client:
+    db = client['name']
+    audio_tasks = db.audio_tasks
+else:
+    print("Failed to connect to MongoDB.")
+
+
+# Redis connection
+redis_url = os.getenv('REDIS_URL', 'redis://redis:6379/0')
+print(f"Connecting to Redis at '{redis_url}'")
+redis_conn = Redis.from_url(redis_url)
+q = Queue(connection=redis_conn)
 
 app = Flask(__name__)
 CORS(app)
-executor = ThreadPoolExecutor(max_workers=2)
+executor = ThreadPoolExecutor(max_workers=24)
+
+def is_valid_youtube_url(url):
+    youtube_regex = ( r'(https?://)?(www\.)?'
+        '(youtube|youtu|youtube-nocookie)\.(com|be)/'
+        '(watch\?v=|embed/|v/|.+\?v=)?([^&=%\?\s]{11})')
+    youtube_pattern = re.compile(youtube_regex)
+    return re.match(youtube_pattern, url) is not None
 
 def cleanup_files(*file_paths):
-    for file_path in file_paths:
-        if os.path.exists(file_path):
-            os.remove(file_path)
+    for file_path in file_paths:
+        if os.path.exists(file_path):
+            os.remove(file_path)
 
 def download_audio(youtube_url):
-    downloaded_mp3 = 'downloaded_audio.mp3'
-    downloaded_webm = 'downloaded_audio.webm'
-    cleanup_files(downloaded_mp3, downloaded_webm)
-    ydl_opts = {
-        'format': 'bestaudio/best',
-        'postprocessors': [{'key': 'FFmpegExtractAudio', 'preferredcodec': 'mp3', 'preferredquality': '192'}],
-        'outtmpl': 'downloaded_audio.%(ext)s',
-        'keepvideo': True,
-    }
-    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
-        ydl.download([youtube_url])
-    return downloaded_mp3, downloaded_webm
-
-def load_and_preprocess_audio(file_path, timestamp):
-    song, sr = torchaudio.load(file_path)
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    song = song.to(device)
-    expected_sr = 32000
-    if sr != expected_sr:
-        resampler = T.Resample(sr, expected_sr).to(device)
-        song = resampler(song)
-        sr = expected_sr
-
-    # Convert timestamp (seconds) to frames
-    frame_offset = int(timestamp * sr)
-
-    # Check if waveform duration after timestamp is less than 30 seconds
-    if song.shape[1] - frame_offset < 30 * sr:
-        # Wrap around to the beginning of the mp3
-        song = torch.cat((song[:, frame_offset:], song[:, :30 * sr - (song.shape[1] - frame_offset)]), dim=1)
-    else:
-        song = song[:, frame_offset:frame_offset + 30 * sr]
-
-    # Define the prompt length
-    prompt_length = 6 * sr
-
-    # Create the prompt waveform
-    prompt_waveform = song[:, :prompt_length] if song.shape[1] > prompt_length else song
-
-    return prompt_waveform, sr
-
-def generate_audio_continuation(prompt_waveform, sr):
-    model_continue = MusicGen.get_pretrained('facebook/musicgen-small')
-    model_continue.set_generation_params(use_sampling=True, top_k=250, top_p=0.0, temperature=1.0, duration=16, cfg_coef=3)
-    output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
+    downloaded_mp3 = 'downloaded_audio.mp3'
+    downloaded_webm = 'downloaded_audio.webm'
+    cleanup_files(downloaded_mp3, downloaded_webm)
+    ydl_opts = {
+        'format': 'bestaudio/best',
+        'postprocessors': [{'key': 'FFmpegExtractAudio', 'preferredcodec': 'mp3', 'preferredquality': '192'}],
+        'outtmpl': 'downloaded_audio.%(ext)s',
+        'keepvideo': True,
+    }
+    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
+        ydl.download([youtube_url])
+    return downloaded_mp3, downloaded_webm
+
+def get_bpm(downloaded_mp3):
+    audio, sr = librosa.load(downloaded_mp3, sr=None)
+    onset_env = librosa.onset.onset_strength(y=audio, sr=sr)
+    tempo, _ = librosa.beat.beat_track(onset_envelope=onset_env, sr=sr)
+    if 120 < tempo < 200:
+        tempo = tempo / 2
+    return tempo
+
+def calculate_duration(bpm, min_duration, max_duration):
+    single_bar_duration = 4 * 60 / bpm
+    bars = max(min_duration // single_bar_duration, 1)
+
+    while single_bar_duration * bars < min_duration:
+        bars += 1
+
+    duration = single_bar_duration * bars
+
+    while duration > max_duration and bars > 1:
+        bars -= 1
+        duration = single_bar_duration * bars
+
+    return duration
+
+def load_and_preprocess_audio(file_path, timestamp, promptLength):
+    song, sr = torchaudio.load(file_path)
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    song = song.to(device)
+    expected_sr = 32000
+    if sr != expected_sr:
+        resampler = T.Resample(sr, expected_sr).to(device)
+        song = resampler(song)
+        sr = expected_sr
+
+    # Convert timestamp (seconds) to frames
+    frame_offset = int(timestamp * sr)
+
+    # Check if waveform duration after timestamp is less than 30 seconds
+    if song.shape[1] - frame_offset < 30 * sr:
+        # Wrap around to the beginning of the mp3
+        song = torch.cat((song[:, frame_offset:], song[:, :30 * sr - (song.shape[1] - frame_offset)]), dim=1)
+    else:
+        song = song[:, frame_offset:frame_offset + 30 * sr]
+
+    # Define the prompt length
+    prompt_length = promptLength * sr
+
+    # Create the prompt waveform
+    prompt_waveform = song[:, :prompt_length] if song.shape[1] > prompt_length else song
+
+    return prompt_waveform, sr
+
+def generate_audio_continuation(prompt_waveform, sr, bpm, model, min_duration, max_duration):
+    # Calculate the duration to end at a bar
+    duration = calculate_duration(bpm, min_duration, max_duration)
+
+    # Use a new CUDA stream for this task
+    stream = torch.cuda.Stream()
+    with torch.cuda.stream(stream):
+        model_continue = MusicGen.get_pretrained(model)
+        model_continue.set_generation_params(use_sampling=True, top_k=250, top_p=0.0, temperature=1.0, duration=duration, cfg_coef=3)
+        output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
     return output.cpu().squeeze(0)
 
 def save_generated_audio(output, sr):
-    output_filename = 'generated_continuation'
-    audio_write(output_filename, output, sr, strategy="loudness", loudness_compressor=True)
-    return output_filename + '.wav'
+    output_filename = 'generated_continuation'
+    audio_write(output_filename, output, sr, strategy="loudness", loudness_compressor=True)
+    return output_filename + '.wav'
+
-def process_youtube_url(youtube_url, timestamp):
+def process_youtube_url(youtube_url, timestamp, model, promptLength, min_duration, max_duration):
     try:
         downloaded_mp3, downloaded_webm = download_audio(youtube_url)
-        prompt_waveform, sr = load_and_preprocess_audio(downloaded_mp3, timestamp)
-        output = generate_audio_continuation(prompt_waveform, sr)
+        bpm = get_bpm(downloaded_mp3)
+        prompt_waveform, sr = load_and_preprocess_audio(downloaded_mp3, timestamp, promptLength)
+        output = generate_audio_continuation(prompt_waveform, sr, bpm, model, min_duration, max_duration)
         output_filename = save_generated_audio(output, sr)
+
+        # Encode the audio data
+        with open(output_filename, 'rb') as audio_file:
+            encoded_audio = base64.b64encode(audio_file.read()).decode('utf-8')
+
+        # Save task info, audio reference, and status in MongoDB
+        audio_tasks.update_one(
+            {'youtube_url': youtube_url, 'timestamp': timestamp},
+            {'$set': {'output_filename': output_filename, 'status': 'completed', 'audio': encoded_audio}}
+        )
+
         cleanup_files(downloaded_mp3, downloaded_webm)
         return output_filename
     except Exception as e:
         print(f"Error processing YouTube URL: {e}")
+        # Update the task status in MongoDB in case of an error
+        audio_tasks.update_one(
+            {'youtube_url': youtube_url, 'timestamp': timestamp},
+            {'$set': {'status': 'failed'}}
+        )
         return None
 
 @app.route('/generate', methods=['POST'])
 def generate_audio():
     data = request.json
     youtube_url = data['url']
-    print_data = request.get_json()
-    pretty_data = json.dumps(print_data, indent=4) # Pretty print the JSON data
-    app.logger.info(f'JSON data received: \n{pretty_data}') # Log the entire JSON data
-    timestamp = data.get('currentTime') # Get the timestamp, default to 0 if not provided
+    timestamp = data.get('currentTime', 0)
+    model = data.get('model', 'facebook/musicgen-small')
+    promptLength = int(data.get('promptLength', 6))
+    duration = data.get('duration', '16-18').split('-')
 
-    # Log the timestamp
-    app.logger.info(f'Timestamp received: {timestamp}')
+    # Ensure that duration is correctly parsed and handled
+    min_duration = int(duration[0])
+    max_duration = int(duration[1])
 
-    audio_path = process_youtube_url(youtube_url, timestamp)
-    if audio_path:
-        with open(audio_path, 'rb') as audio_file:
-            encoded_audio = base64.b64encode(audio_file.read()).decode('utf-8')
-        cleanup_files(audio_path)
-        return jsonify({"audio": encoded_audio})
-    else:
-        return jsonify({"error": "Failed to process audio"}), 500
+
+    # Validate YouTube URL
+    if not is_valid_youtube_url(youtube_url):
+        return jsonify({"error": "Invalid YouTube URL"}), 400
+
+    # Validate timestamp
+    if not isinstance(timestamp, (int, float)) or timestamp < 0:
+        return jsonify({"error": "Invalid timestamp"}), 400
+
+    # Enqueue the task with retry logic
+    job = q.enqueue(
+        process_youtube_url,
+        youtube_url,
+        timestamp,
+        model,
+        promptLength,
+        min_duration,
+        max_duration,
+        job_timeout=600,
+        retry=Retry(max=3)
+    )
+
+    # Save task info in MongoDB
+    audio_task = {
+        'rq_job_id': job.get_id(),
+        'youtube_url': youtube_url,
+        'timestamp': timestamp,
+        'status': 'pending'
+    }
+    task_id = audio_tasks.insert_one(audio_task).inserted_id
+
+    return jsonify({"task_id": str(task_id)})
+
+@app.route('/continue', methods=['POST'])
+def continue_audio():
+    data = request.json
+    task_id = data['task_id']
+    musicgen_model = data['model']
+    prompt_duration = int(data.get('prompt_duration', 6))
+
+    task = audio_tasks.find_one({'_id': ObjectId(task_id)})
+    if not task:
+        return jsonify({"error": "Task not found"}), 404
+
+    input_data_base64 = task['audio']
+    output_data_base64 = continue_music(input_data_base64, musicgen_model, prompt_duration=prompt_duration)
+    task['audio'] = output_data_base64
+    task['status'] = 'completed'
+    audio_tasks.update_one({'_id': ObjectId(task_id)}, {"$set": task})
+
+    return jsonify({"task_id": task_id, "audio": output_data_base64})
+
+@app.route('/tasks/<jobId>', methods=['GET'])
+def get_task(jobId):
+    try:
+        task = audio_tasks.find_one({'_id': ObjectId(jobId)})
+        if task:
+            return Response(json.dumps(task, default=json_util.default), mimetype='application/json')
+        else:
+            return jsonify({"error": "Task not found"}), 404
+    except bson.errors.InvalidId:
+        return jsonify({"error": "Invalid ObjectId format"}), 400
 
 if __name__ == '__main__':
-    app.run(debug=True, threaded=True)
\ No newline at end of file
+    app.run(debug=True, threaded=True)
diff --git a/concurrent_gary.py b/concurrent_gary.py
index cf814d3..8c04255 100644
--- a/concurrent_gary.py
+++ b/concurrent_gary.py
@@ -24,7 +24,6 @@
 from g4laudio import continue_music
 
 
-
 # MongoDB connection with retry logic
 def get_mongo_client():
     try:
@@ -42,7 +41,6 @@ def get_mongo_client():
 else:
     print("Failed to connect to MongoDB.")
 
-
 # Redis connection
 redis_url = os.getenv('REDIS_URL', 'redis://redis:6379/0')
 print(f"Connecting to Redis at '{redis_url}'")
@@ -62,26 +60,51 @@ def is_valid_youtube_url(url):
     return re.match(youtube_pattern, url) is not None
 
 def cleanup_files(*file_paths):
-    for file_path in file_paths:
-        if os.path.exists(file_path):
-            os.remove(file_path)
+    for file_path in file_paths:
+        if os.path.exists(file_path) and file_path.endswith('.webm'):
+            os.remove(file_path)
 
 def download_audio(youtube_url):
-    downloaded_mp3 = 'downloaded_audio.mp3'
-    downloaded_webm = 'downloaded_audio.webm'
-    cleanup_files(downloaded_mp3, downloaded_webm)
-    ydl_opts = {
-        'format': 'bestaudio/best',
-        'postprocessors': [{'key': 'FFmpegExtractAudio', 'preferredcodec': 'mp3', 'preferredquality': '192'}],
-        'outtmpl': 'downloaded_audio.%(ext)s',
-        'keepvideo': True,
-    }
-    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
-        ydl.download([youtube_url])
-    return downloaded_mp3, downloaded_webm
-
-def get_bpm(downloaded_mp3):
-    audio, sr = librosa.load(downloaded_mp3, sr=None)
+    cache_dir = '/dataset/gary'
+    if not os.path.exists(cache_dir):
+        os.makedirs(cache_dir)
+
+    # Check Redis cache
+    audio_id = base64.urlsafe_b64encode(youtube_url.encode()).decode('utf-8')
+    cached_mp3_path = redis_conn.get(audio_id)
+
+    if cached_mp3_path:
+        cached_mp3_path = cached_mp3_path.decode('utf-8')
+        if os.path.exists(cached_mp3_path):
+            print(f"Using cached audio for URL: {youtube_url}")
+            return cached_mp3_path
+
+    downloaded_mp3 = 'downloaded_audio.mp3'
+    downloaded_webm = 'downloaded_audio.webm'
+    cleanup_files(downloaded_webm)
+
+    ydl_opts = {
+        'format': 'bestaudio/best',
+        'postprocessors': [{'key': 'FFmpegExtractAudio', 'preferredcodec': 'mp3', 'preferredquality': '192'}],
+        'outtmpl': 'downloaded_audio.%(ext)s',
+        'keepvideo': True,
+    }
+
+    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
+        ydl.download([youtube_url])
+
+    # Move the downloaded file to the cache directory
+    cached_mp3_path = os.path.join(cache_dir, f'{audio_id}.mp3')
+    os.rename(downloaded_mp3, cached_mp3_path)
+    cleanup_files(downloaded_webm)
+
+    # Store the cached file path in Redis
+    redis_conn.set(audio_id, cached_mp3_path)
+
+    return cached_mp3_path
+
+def get_bpm(cached_mp3_path):
+    audio, sr = librosa.load(cached_mp3_path, sr=None)
     onset_env = librosa.onset.onset_strength(y=audio, sr=sr)
     tempo, _ = librosa.beat.beat_track(onset_envelope=onset_env, sr=sr)
     if 120 < tempo < 200:
@@ -148,12 +171,11 @@ def save_generated_audio(output, sr):
     audio_write(output_filename, output, sr, strategy="loudness", loudness_compressor=True)
     return output_filename + '.wav'
 
-
 def process_youtube_url(youtube_url, timestamp, model, promptLength, min_duration, max_duration):
     try:
-        downloaded_mp3, downloaded_webm = download_audio(youtube_url)
-        bpm = get_bpm(downloaded_mp3)
-        prompt_waveform, sr = load_and_preprocess_audio(downloaded_mp3, timestamp, promptLength)
+        cached_mp3_path = download_audio(youtube_url)
+        bpm = get_bpm(cached_mp3_path)
+        prompt_waveform, sr = load_and_preprocess_audio(cached_mp3_path, timestamp, promptLength)
         output = generate_audio_continuation(prompt_waveform, sr, bpm, model, min_duration, max_duration)
         output_filename = save_generated_audio(output, sr)
 
@@ -167,7 +189,6 @@ def process_youtube_url(youtube_url, timestamp, model, promptLength, min_duratio
             {'$set': {'output_filename': output_filename, 'status': 'completed', 'audio': encoded_audio}}
         )
 
-        cleanup_files(downloaded_mp3, downloaded_webm)
         return output_filename
     except Exception as e:
         print(f"Error processing YouTube URL: {e}")
@@ -191,7 +212,6 @@ def generate_audio():
     min_duration = int(duration[0])
     max_duration = int(duration[1])
 
-    # Validate YouTube URL
     if not is_valid_youtube_url(youtube_url):
         return jsonify({"error": "Invalid YouTube URL"}), 400
 
diff --git a/g4laudio.py b/g4laudio.py
index 1859767..61edbf3 100644
--- a/g4laudio.py
+++ b/g4laudio.py
@@ -12,11 +12,21 @@ def generate_session_id():
     """Generate a unique session ID."""
     return str(uuid.uuid4())
 
+# Function to normalize audio to a target peak amplitude
+def peak_normalize(y, target_peak=0.9):
+    return target_peak * (y / np.max(np.abs(y)))
+
+# Function to normalize audio to a target RMS value
+def rms_normalize(y, target_rms=0.05):
+    current_rms = np.sqrt(np.mean(y**2))
+    return y * (target_rms / current_rms)
+
 def preprocess_audio(waveform):
     waveform_np = waveform.cpu().squeeze().numpy() # Move tensor to CPU and convert to numpy
     processed_waveform_np = waveform_np # Use the waveform as-is without processing
     return torch.from_numpy(processed_waveform_np).unsqueeze(0).cuda() # Convert back to tensor and move to GPU
 
+# Function to wrap audio if needed
 def wrap_audio_if_needed(waveform, sr, desired_duration):
     current_duration = waveform.shape[-1] / sr
     while current_duration < desired_duration:
@@ -25,9 +35,7 @@ def wrap_audio_if_needed(waveform, sr, desired_duration):
     return waveform
 
 def process_audio(input_data_base64, model_name, progress_callback=None, prompt_duration=6):
-    # Use a new CUDA stream for this task
-    stream = torch.cuda.Stream()
-    with torch.cuda.stream(stream):
+    try:
         # Decode the base64 input data
         input_data = base64.b64decode(input_data_base64)
         input_audio = io.BytesIO(input_data)
@@ -39,8 +47,9 @@ def process_audio(input_data_base64, model_name, progress_callback=None, prompt_
         # Model's expected sample rate
         expected_sr = 32000 # Adjust this value based on your model's requirements
 
-        # Resample if necessary
+        # Check if the input audio's sample rate matches the model's expected sample rate
         if sr != expected_sr:
+            # Resample the audio to match the model's expected sample rate
             resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=expected_sr).cuda()
             song_resampled = resampler(song)
         else:
@@ -53,7 +62,7 @@ def process_audio(input_data_base64, model_name, progress_callback=None, prompt_
         model_continue = MusicGen.get_pretrained(model_name)
         model_continue.set_custom_progress_callback(progress_callback)
 
-        # Set generation parameters
+        # Setting generation parameters
         output_duration = song_resampled.shape[-1] / expected_sr
         model_continue.set_generation_params(
             use_sampling=True,
@@ -77,17 +86,26 @@ def process_audio(input_data_base64, model_name, progress_callback=None, prompt_
         output_audio.seek(0)
         output_data_base64 = base64.b64encode(output_audio.read()).decode('utf-8')
 
-        return output_data_base64
+        # Clear GPU memory
+        del song, song_resampled, processed_waveform, prompt_waveform, output
+        torch.cuda.empty_cache()
+
+        return output_data_base64
+    except Exception as e:
+        print(f"Error processing audio: {e}")
+        torch.cuda.empty_cache()
+        raise
 
 def continue_music(input_data_base64, musicgen_model, progress_callback=None, prompt_duration=6):
-    stream = torch.cuda.Stream()
-    with torch.cuda.stream(stream):
+    try:
+        # Decode the base64 input data
         input_data = base64.b64decode(input_data_base64)
         input_audio = io.BytesIO(input_data)
         song, sr = torchaudio.load(input_audio)
-        song = song.to('cuda')
+        song = song.to('cuda') # Assume CUDA is available and preferred
+
+        # Normalize audio channels
         if song.size(0) == 1:
             song = song.repeat(2, 1) # Make stereo if mono
@@ -102,23 +120,37 @@ def continue_music(input_data_base64, musicgen_model, progress_callback=None, pr
             cfg_coef=3.0
         )
 
+        # Generate continuation
         prompt_waveform = song[:, -int(prompt_duration * sr):]
        output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
-        output = output.squeeze(0) if output.dim() == 3 else output
+        output = output.squeeze(0) if output.dim() == 3 else output # Ensure 2D tensor for audio
 
+        # Resample output if necessary
         if sr != 32000:
             resampler = T.Resample(orig_freq=32000, new_freq=sr).to('cuda')
             output = resampler(output)
 
+        # Ensure all tensors are on the same device and have the same number of channels
         original_minus_prompt = song[:, :-int(prompt_duration * sr)]
         if original_minus_prompt.size(0) != output.size(0):
+            # Adjust channel numbers if needed
             output = output.repeat(original_minus_prompt.size(0), 1)
 
+        # Concatenate tensors
         combined_waveform = torch.cat([original_minus_prompt, output], dim=1).to('cuda')
 
+        # Save output
         output_audio = io.BytesIO()
         torchaudio.save(output_audio, format='wav', src=combined_waveform.cpu(), sample_rate=sr)
         output_audio.seek(0)
         output_data_base64 = base64.b64encode(output_audio.read()).decode('utf-8')
 
-        return output_data_base64
+        # Clear GPU memory
+        del song, prompt_waveform, output, combined_waveform
+        torch.cuda.empty_cache()
+
+        return output_data_base64
+    except Exception as e:
+        print(f"Error continuing music: {e}")
+        torch.cuda.empty_cache()
+        raise
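
For reference, a minimal client sketch against the two endpoints this patch wires up (POST /generate to enqueue a job, GET /tasks/<jobId> to poll the task document). The base URL, port, example YouTube URL, output filename, and polling interval are assumptions, not part of the patch, and the sketch presumes an RQ worker is consuming the default queue (e.g. rq worker --url redis://redis:6379/0) so enqueued jobs actually run.

    # hypothetical client sketch -- assumes the Flask app listens on http://localhost:5000
    import base64
    import time
    import requests

    BASE = "http://localhost:5000"  # assumption: the patch never binds a host or port

    # Fields mirror what generate_audio() reads from request.json
    resp = requests.post(f"{BASE}/generate", json={
        "url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",  # any valid YouTube URL
        "currentTime": 30,
        "model": "facebook/musicgen-small",
        "promptLength": 6,
        "duration": "16-18",
    })
    task_id = resp.json()["task_id"]

    # Poll the task document until the worker marks it completed or failed
    status = "pending"
    while status not in ("completed", "failed"):
        time.sleep(5)
        task = requests.get(f"{BASE}/tasks/{task_id}").json()
        status = task.get("status", "pending")

    if status == "completed":
        # 'audio' holds the base64-encoded WAV written by process_youtube_url()
        with open("continuation.wav", "wb") as f:
            f.write(base64.b64decode(task["audio"]))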