Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix error-to-signal loss in notebook #12

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions guitar_lstm_colab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras import Sequential\n",
"from tensorflow.keras.layers import LSTM, Conv1D, Dense\n",
"from tensorflow.keras.layers import LSTM, Conv1D, Dense, TimeDistributed\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras.backend import clear_session\n",
"from tensorflow.keras.activations import tanh, elu, relu\n",
Expand Down Expand Up @@ -76,7 +76,7 @@
" \n",
" def __init__(self, x, y, window_len, batch_size=32):\n",
" self.x = x\n",
" self.y = y[window_len-1:] \n",
" self.y = y\n",
" self.window_len = window_len\n",
" self.batch_size = batch_size\n",
" \n",
Expand All @@ -85,7 +85,7 @@
" \n",
" def __getitem__(self, index):\n",
" x_out = np.stack([self.x[idx: idx+self.window_len] for idx in range(index*self.batch_size, (index+1)*self.batch_size)])\n",
" y_out = self.y[index*self.batch_size:(index+1)*self.batch_size]\n",
" y_out = np.stack([self.y[idx: idx+self.window_len] for idx in range(index*self.batch_size, (index+1)*self.batch_size)])\n",
" return x_out, y_out"
]
},
Expand All @@ -98,14 +98,14 @@
"outputs": [],
"source": [
"def pre_emphasis_filter(x, coeff=0.95):\n",
" return tf.concat([x, x - coeff * x], 1)\n",
" return tf.concat([x[:, 0:1, :], x[:, 1:, :] - coeff*x[:, :-1, :]], axis=1)\n",
" \n",
"def error_to_signal(y_true, y_pred): \n",
" \"\"\"\n",
" Error to signal ratio with pre-emphasis filter:\n",
" \"\"\"\n",
" y_true, y_pred = pre_emphasis_filter(y_true), pre_emphasis_filter(y_pred)\n",
" return K.sum(tf.pow(y_true - y_pred, 2), axis=0) / (K.sum(tf.pow(y_true, 2), axis=0) + 1e-10)\n",
" return K.sum(tf.pow(y_true - y_pred, 2), axis=1) / (K.sum(tf.pow(y_true, 2), axis=1) + 1e-10)\n",
" \n",
"def save_wav(name, data):\n",
" wavfile.write(name, 44100, data.flatten().astype(np.float32))\n",
Expand Down Expand Up @@ -136,21 +136,21 @@
" --training_mode=2 Extended training (set max_epochs as desired, for example 50+)\n",
"'''\n",
"\n",
"batch_size = 4096 \n",
"batch_size = 4096\n",
"test_size = 0.2\n",
"\n",
"if train_mode == 0: # Speed Training\n",
" learning_rate = 0.01 \n",
" learning_rate = 3e-4\n",
" conv1d_strides = 12 \n",
" conv1d_filters = 16\n",
" hidden_units = 36\n",
"elif train_mode == 1: # Accuracy Training (~10x longer than Speed Training)\n",
" learning_rate = 0.01 \n",
" learning_rate = 3e-4 \n",
" conv1d_strides = 4\n",
" conv1d_filters = 36\n",
" hidden_units= 64\n",
"else: # Extended Training (~60x longer than Accuracy Training)\n",
" learning_rate = 0.0005 \n",
" learning_rate = 3e-4\n",
" conv1d_strides = 3\n",
" conv1d_filters = 36\n",
" hidden_units= 96\n",
Expand All @@ -159,11 +159,11 @@
"# Create Sequential Model ###########################################\n",
"clear_session()\n",
"model = Sequential()\n",
"model.add(Conv1D(conv1d_filters, 12,strides=conv1d_strides, activation=None, padding='same',input_shape=(input_size,1)))\n",
"model.add(Conv1D(conv1d_filters, 12,strides=conv1d_strides, activation=None, padding='same'))\n",
"model.add(LSTM(hidden_units))\n",
"model.add(Dense(1, activation=None))\n",
"model.compile(optimizer=Adam(learning_rate=learning_rate), loss='mse', metrics=[error_to_signal])\n",
"model.add(Conv1D(conv1d_filters, 12, activation=None, padding='same',input_shape=(input_size,1)))\n",
"model.add(Conv1D(conv1d_filters, 12, activation=None, padding='same'))\n",
"model.add(LSTM(hidden_units, return_sequences=True))\n",
"model.add(TimeDistributed(Dense(1, activation=None)))\n",
"model.compile(optimizer=Adam(learning_rate=learning_rate), loss=error_to_signal, metrics=['mae', 'mse', error_to_signal])\n",
"model.summary()\n",
"\n",
"# Load and Preprocess Data ###########################################\n",
Expand All @@ -180,7 +180,7 @@
"val_arr = WindowArray(X_all[train_examples:], y_all[train_examples:], input_size, batch_size=batch_size)\n",
"\n",
"# Train Model ###################################################\n",
"history = model.fit(train_arr, validation_data=val_arr, epochs=epochs, shuffle=True) \n",
"history = model.fit(train_arr, validation_data=val_arr, epochs=1, shuffle=True) \n",
"model.save('models/'+name+'/'+name+'.h5')\n",
"\n",
"# Run Prediction #################################################\n",
Expand All @@ -194,7 +194,9 @@
"\n",
"prediction = model.predict(test_arr)\n",
"\n",
"save_wav('models/'+name+'/y_pred.wav', prediction)\n",
"# The full prediction has a lot of redundant frames, \n",
"# so we only pick the last frame from each sample in the window\n",
"save_wav('models/'+name+'/y_pred.wav', prediction[:, -1, :])\n",
"save_wav('models/'+name+'/x_test.wav', x_last_part)\n",
"save_wav('models/'+name+'/y_test.wav', y_test)\n",
"\n",
Expand Down