Add a sample notebook for running navigator in parallel #432

Status: Open. Wants to merge 1 commit into base: main.
258 changes: 258 additions & 0 deletions sdk_blueprints/Gretel_Navigator_IAPI_Parallel_Blueprint.ipynb
@@ -0,0 +1,258 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "awlODvx7fQeB"
},
"source": [
"\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gretelai/gretel-blueprints/blob/main/sdk_blueprints/Gretel_Navigator_IAPI_Parallel_Blueprint.ipynb)\n",
"\n",
"<br>\n",
"\n",
"<center><img src=\"https://gretel-public-website.s3.us-west-2.amazonaws.com/assets/brand/gretel_brand_wordmark.svg\" alt=\"Gretel\" width=\"350\"/></center>\n",
"\n",
"<br>\n",
"\n",
"## 👋 Welcome to the Navigator real-time inference API Parallel Blueprint!\n",
"\n",
"In this Blueprint, we speed up data generation by sending requests to Navigator from multiple threads in parallel.\n",
"\n",
"\n",
"<br>\n",
"\n",
"## ✅ Set up your Gretel account\n",
"\n",
"To get started, you will need a [free Gretel account](https://console.gretel.ai/).\n",
"\n",
"<br>\n",
"\n",
"#### Ready? Let's go 🚀"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HplFGj5HNZiJ"
},
"source": [
"## 💾 Install `gretel-client` and its dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IZPDLpEPIXSW"
},
"outputs": [],
"source": [
"%%capture\n",
"!pip install gretel-client"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JHqesHdmOCe_"
},
"source": [
"## 🛜 Configure your Gretel session\n",
"\n",
"- [The Gretel object](https://docs.gretel.ai/create-synthetic-data/gretel-sdk/the-gretel-object) provides a high-level interface for streamlining interactions with Gretel's APIs.\n",
"\n",
"- Retrieve your Gretel API key [here](https://console.gretel.ai/users/me/key)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CLIgnOzcNpHD"
},
"outputs": [],
"source": [
"from gretel_client import Gretel\n",
"\n",
"gretel = Gretel(api_key=\"prompt\", validate=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7GO2SLE3ebdw"
},
"source": [
"## 🚀 Real-time inference API\n",
"\n",
"- The Navigator real-time inference API makes it possible to programmatically run Navigator outside the [Gretel Console](https://console.gretel.ai/navigator).\n",
"\n",
"- Our [Python SDK](https://github.com/gretelai/gretel-python-client) provides an intuitive high-level interface for the Navigator API.\n",
"\n",
"- Navigator currently supports two data generation modes: `\"tabular\"` and `\"natural_language\"`. In both modes, you can choose the backend model that powers the generation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xsiUQH5hP9TV"
},
"outputs": [],
"source": [
"# list \"tabular\" backend models\n",
"gretel.factories.get_navigator_model_list(\"tabular\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0zGw50JLQDWP"
},
"outputs": [],
"source": [
"# list \"natural_language\" backend models\n",
"gretel.factories.get_navigator_model_list(\"natural_language\")"
]
},
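{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Optional aside:** as a quick illustration of `\"natural_language\"` mode, the sketch below assumes the natural-language API object exposes a `generate` method that accepts a prompt string plus the usual sampling parameters (see the SDK docs for the exact signature):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch only -- assumes generate(prompt, ...) on the natural_language API object.\n",
"llm = gretel.factories.initialize_navigator_api(\"natural_language\")\n",
"text = llm.generate(\"Briefly describe what a bank transaction ledger contains.\", temperature=0.7)\n",
"print(text)"
]
},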
{
"cell_type": "markdown",
"metadata": {
"id": "PrbIPm_5QKX4"
},
"source": [
"**Notes:**\n",
"\n",
"- `gretelai/auto` automatically selects the current default model, which will change with time as models continue to evolve.\n",
"\n",
"- The `factories` attribute of the `Gretel` object provides methods for creating new objects that interact with Gretel's APIs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ru2Ogl83BSqn"
},
"source": [
"## 📊 Parallel tabular data generation\n",
"\n",
"- We use the `initialize_navigator_api` method of the `factories` attribute to create a separate Navigator API object for each thread.\n",
"\n",
"- With `model_type = \"tabular\"` (which is the default), we initialize Navigator's tabular API.\n",
"\n",
"- To select a different backend model, use the optional `backend_model` argument, which we've set to `gretelai/auto`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1InUug1Aeahi"
},
"outputs": [],
"source": [
"import random\n",
"from concurrent.futures import ThreadPoolExecutor\n",
"from threading import Lock\n",
"\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"def generate_random_params():\n",
" \"\"\"\n",
" Generate random values for LLM parameters to ensure moderate creativity.\n",
"\n",
" Returns:\n",
" dict: A dictionary containing random values for temperature, top_p, and top_k.\n",
" \"\"\"\n",
" params = {\n",
" \"temperature\": round(\n",
" random.uniform(0.5, 0.75), 2\n",
" ), # Random float between 0.5 and 0.75\n",
" \"top_p\": round(\n",
" random.uniform(0.8, 0.95), 2\n",
" ), # Random float between 0.8 and 0.95\n",
" \"top_k\": random.randint(30, 45), # Random integer between 30 and 45\n",
" }\n",
" return params\n",
"\n",
"\n",
"def generate_records_parallel(prompt: str, num_records=25, num_threads=5):\n",
" shared_df = pd.DataFrame()\n",
"\n",
" mutex = Lock()\n",
"\n",
" def generate_data(progress: tqdm):\n",
" tabular = gretel.factories.initialize_navigator_api(\n",
" \"tabular\", backend_model=\"gretelai/auto\"\n",
" )\n",
" nonlocal mutex, shared_df\n",
" GENERATE_PARAMS = generate_random_params()\n",
" try:\n",
" for item in tabular.generate(\n",
" prompt,\n",
" num_records=num_records,\n",
" stream=True,\n",
" disable_progress_bar=True,\n",
" **GENERATE_PARAMS\n",
" ):\n",
" with mutex:\n",
" shared_df = pd.concat(\n",
" [shared_df, pd.DataFrame(item, index=[0])], ignore_index=True\n",
" )\n",
" progress.update(1)\n",
" except Exception as e:\n",
" print(f\"Worker thread failed: {e}\")\n",
"\n",
" with tqdm(total=num_records * num_threads) as progress, ThreadPoolExecutor(\n",
" num_threads\n",
" ) as executor:\n",
" for _ in range(num_threads):\n",
" executor.submit(generate_data, progress)\n",
" return shared_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9T7EydHVS4vd"
},
"outputs": [],
"source": [
"prompt = \"\"\"\n",
"Generate customer bank transaction data. Include the following columns:\n",
"- customer_name\n",
"- customer_id\n",
"- transaction_date\n",
"- transaction_amount\n",
"- transaction_type\n",
"- transaction_category\n",
"- account_balance\n",
"\"\"\"\n",
"num_records = 25\n",
"num_threads = 5\n",
"\n",
"df = generate_records_parallel(prompt, num_records=num_records, num_threads=num_threads)"
]
}
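,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🔍 Inspect the results\n",
"\n",
"Each thread generates `num_records` rows, so the combined DataFrame should contain up to `num_records * num_threads` rows (fewer if any worker hit an error):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(f\"Generated {len(df)} records across {num_threads} threads\")\n",
"df.head()"
]
}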
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}