Skip to content

Commit

Permalink
add quantization
Browse files Browse the repository at this point in the history
  • Loading branch information
jmduarte committed Jun 5, 2024
1 parent f1f7670 commit f662bd7
Showing 1 changed file with 153 additions and 1 deletion.
154 changes: 153 additions & 1 deletion notebooks/08_Compression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"metadata": {},
"outputs": [],
"source": [
"data = fetch_openml(\"hls4ml_lhc_jets_hlf\", parser=\"auto\")\n",
"data = fetch_openml(\"hls4ml_lhc_jets_hlf\")\n",
"X, y = data[\"data\"], data[\"target\"]\n",
"\n",
"le = LabelEncoder()\n",
Expand Down Expand Up @@ -288,6 +288,158 @@
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Post-training quantization of the pruned model using TFLite's default optimizations\n",
"converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model)\n",
"converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
"tflite_model_quant = converter.convert()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def representative_data_gen():\n",
"    \"\"\"Yield 100 single-sample float32 batches of X_train_val for quantization calibration.\"\"\"\n",
"    calibration_data = tf.data.Dataset.from_tensor_slices(X_train_val.astype(np.float32))\n",
"    for input_value in calibration_data.batch(1).take(100):\n",
"        # Model has only one input so each data point has one element.\n",
"        yield [input_value]\n",
"\n",
"\n",
"# Quantize the pruned model so the result matches the \"pruned+quantized\" label\n",
"# used when its accuracy is reported.\n",
"converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model)\n",
"converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
"converter.representative_dataset = representative_data_gen\n",
"\n",
"tflite_model_quant = converter.convert()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pathlib\n",
"\n",
"# Make sure the output folder exists before writing into it\n",
"tflite_models_dir = pathlib.Path(\"tflite_models/\")\n",
"tflite_models_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
"# Write the serialized quantized model to disk\n",
"tflite_model_quant_file = tflite_models_dir.joinpath(\"model_quant.tflite\")\n",
"tflite_model_quant_file.write_bytes(tflite_model_quant)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Helper function to run inference on a TFLite model\n",
"def run_tflite_model(tflite_file, X_test_indices):\n",
"    \"\"\"Run a TFLite model on selected rows of the global ``X_test``.\n",
"\n",
"    Args:\n",
"        tflite_file: Path to the ``.tflite`` model file.\n",
"        X_test_indices: Indices into ``X_test`` of the samples to evaluate.\n",
"\n",
"    Returns:\n",
"        A float32 array of shape ``(len(X_test_indices), n_outputs)`` holding\n",
"        the model output for each requested sample.\n",
"    \"\"\"\n",
"    # Initialize the interpreter\n",
"    interpreter = tf.lite.Interpreter(model_path=str(tflite_file))\n",
"    interpreter.allocate_tensors()\n",
"\n",
"    input_details = interpreter.get_input_details()[0]\n",
"    output_details = interpreter.get_output_details()[0]\n",
"\n",
"    # Size the result from the model's output tensor instead of hard-coding the class count\n",
"    n_outputs = int(output_details[\"shape\"][-1])\n",
"    predictions = np.zeros((len(X_test_indices), n_outputs), dtype=np.float32)\n",
"\n",
"    quantized_dtypes = (np.uint8, np.int8)\n",
"    input_is_quantized = input_details[\"dtype\"] in quantized_dtypes\n",
"    output_is_quantized = output_details[\"dtype\"] in quantized_dtypes\n",
"\n",
"    for i, X_test_index in enumerate(X_test_indices):\n",
"        X_test_i = X_test[X_test_index]\n",
"\n",
"        # Quantized input tensor: rescale the real-valued data into the integer domain\n",
"        if input_is_quantized:\n",
"            input_scale, input_zero_point = input_details[\"quantization\"]\n",
"            X_test_i = X_test_i / input_scale + input_zero_point\n",
"\n",
"        X_test_i = np.expand_dims(X_test_i, axis=0).astype(input_details[\"dtype\"])\n",
"        interpreter.set_tensor(input_details[\"index\"], X_test_i)\n",
"        interpreter.invoke()\n",
"        output = interpreter.get_tensor(output_details[\"index\"])[0]\n",
"\n",
"        # Quantized output tensor: map the integers back to real values\n",
"        if output_is_quantized:\n",
"            output_scale, output_zero_point = output_details[\"quantization\"]\n",
"            output = (output.astype(np.float32) - output_zero_point) * output_scale\n",
"\n",
"        predictions[i] = output\n",
"\n",
"    return predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Evaluate the quantized model on every test sample\n",
"X_test_indices = list(range(len(X_test)))\n",
"y_quant = run_tflite_model(tflite_model_quant_file, X_test_indices)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare the arg-max class of the targets and of the quantized model's outputs\n",
"true_labels = np.argmax(y_test, axis=1)\n",
"quant_labels = np.argmax(y_quant, axis=1)\n",
"print(\"Accuracy pruned+quantized: {}\".format(accuracy_score(true_labels, quant_labels)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib.legend import Legend\n",
"from matplotlib.lines import Line2D\n",
"\n",
"fig, ax = plt.subplots(figsize=(9, 9))\n",
"\n",
"# Draw the ROC curves of the three sets of predictions on the same axes,\n",
"# one line style per set\n",
"_ = plotting.make_roc(y_test, y_ref, classes)\n",
"plt.gca().set_prop_cycle(None)  # reset the colors\n",
"_ = plotting.make_roc(y_test, y_prune, classes, linestyle=\"--\")\n",
"plt.gca().set_prop_cycle(None)  # reset the colors\n",
"_ = plotting.make_roc(y_test, y_quant, classes, linestyle=\"-.\")\n",
"\n",
"# Legend mapping each line style to the predictions it represents\n",
"lines = [Line2D([0], [0], ls=style) for style in (\"-\", \"--\", \"-.\")]\n",
"leg = Legend(ax, lines, labels=[\"Unpruned\", \"Pruned\", \"Quantized\"], loc=\"lower right\", frameon=False)\n",
"ax.add_artist(leg)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(y_quant)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(y_prune)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(y_ref)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit f662bd7

Please sign in to comment.