notebook for RAG using gemma, huggingface and elastic
1 parent 41c0a05, commit df2fdd1
Showing 1 changed file with 359 additions and 0 deletions.
notebooks/integrations/gemma/rag-gemma-huggingface-elastic.ipynb
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,359 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "faa09879-9128-4864-8bb5-945ef9b8e84c", | ||
"metadata": {}, | ||
"source": [ | ||
"# RAG using Google's Gemma, Hugging Face and Elastic" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d047438b-6f18-47ed-aac9-12c741cefd06", | ||
"metadata": {}, | ||
"source": [ | ||
"In this notebook, we build a RAG (retrieval-augmented generation) system around [Google's Gemma](https://ai.google.dev/gemma) model. We generate sparse vectors with [Elastic's ELSER](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-elser.html) model and store them in Elasticsearch, retrieve the passages most relevant to a question with semantic search, and pass the top results to Gemma as context. We use the [Hugging Face transformers](https://huggingface.co/google/gemma-2b-it) library to load Gemma in a local environment." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "1bd3acec-d490-4139-bab1-b874e1e7db8d", | ||
"metadata": {}, | ||
"source": [ | ||
"## Setup" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ef406b8a-03fb-49c5-baed-18e03bcd36d9", | ||
"metadata": {}, | ||
"source": [ | ||
"**Elastic Credentials** - Create an [Elastic Cloud deployment](https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud) to get the Elastic credentials (`ELASTIC_CLOUD_ID`, `ELASTIC_API_KEY`).\n", | ||
"\n", | ||
"**Hugging Face Token** - To get started with the [Gemma](https://huggingface.co/google/gemma-2b-it) model, you must accept its terms on Hugging Face and generate an [access token](https://huggingface.co/docs/hub/en/security-tokens) with the `write` role.\n", | ||
"\n", | ||
"**Gemma Model** - We use [gemma-2b-it](https://huggingface.co/google/gemma-2b-it), though Google has released four open models. You can use any of the others instead: [gemma-2b](https://huggingface.co/google/gemma-2b), [gemma-7b](https://huggingface.co/google/gemma-7b), or [gemma-7b-it](https://huggingface.co/google/gemma-7b-it)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ac91d7a3-1198-4b11-a9c5-50028abc861b", | ||
"metadata": {}, | ||
"source": [ | ||
"## Install packages" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "fda41538-444c-48d7-80a0-b34b2e158b82", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"pip install -q -U elasticsearch langchain transformers huggingface_hub" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "15c2e924-e5a2-439b-8e98-f13a162db7fe", | ||
"metadata": {}, | ||
"source": [ | ||
"## Import packages" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "7219411b-fae6-4c2a-b170-796bc30ed073", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import json\n", | ||
"import os\n", | ||
"from getpass import getpass\n", | ||
"from urllib.request import urlopen\n", | ||
"\n", | ||
"from elasticsearch import Elasticsearch, helpers\n", | ||
"from langchain.text_splitter import CharacterTextSplitter\n", | ||
"from langchain.vectorstores import ElasticsearchStore\n", | ||
"from langchain import HuggingFacePipeline\n", | ||
"from langchain.chains import RetrievalQA\n", | ||
"from langchain.prompts import ChatPromptTemplate\n", | ||
"from langchain.schema.output_parser import StrOutputParser\n", | ||
"from langchain.schema.runnable import RunnablePassthrough\n", | ||
"from huggingface_hub import login\n", | ||
"from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "182a413f-e7fd-4361-8096-90736d3df33e", | ||
"metadata": {}, | ||
"source": [ | ||
"## Get Credentials" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b184b3a5-0cc8-43f9-b15d-f5ccf48f574b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"ELASTIC_API_KEY = getpass(\"Elastic API Key :\")\n", | ||
"ELASTIC_CLOUD_ID = getpass(\"Elastic Cloud ID :\")\n", | ||
"elastic_index_name = \"gemma-rag\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a2efbd81-70b9-409c-ab5f-796d538b42a1", | ||
"metadata": {}, | ||
"source": [ | ||
"## Add documents" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "161dfb9d-f11f-4de5-8489-6464ade0cdb2", | ||
"metadata": {}, | ||
"source": [ | ||
"### Download the sample dataset and deserialize the documents" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "49427546-7b37-48f4-a6fe-395736ea2d38", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"url = \"https://raw.githubusercontent.com/ashishtiwari1993/langchain-elasticsearch-RAG/main/data.json\"\n", | ||
"\n", | ||
"response = urlopen(url)\n", | ||
"\n", | ||
"workplace_docs = json.loads(response.read())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f3bf0104-8b31-4b39-ad21-b372fd1fa0db", | ||
"metadata": {}, | ||
"source": [ | ||
"### Split Documents into Passages" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "79e55ed1-418e-48ed-b3e3-d28e10744eb5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"metadata = []\n", | ||
"content = []\n", | ||
"\n", | ||
"for doc in workplace_docs:\n", | ||
"    content.append(doc[\"content\"])\n", | ||
"    metadata.append({\n", | ||
"        \"name\": doc[\"name\"],\n", | ||
"        \"summary\": doc[\"summary\"],\n", | ||
"        \"rolePermissions\": doc[\"rolePermissions\"],\n", | ||
"    })\n", | ||
"\n", | ||
"text_splitter = CharacterTextSplitter(chunk_size=50, chunk_overlap=0)\n", | ||
"docs = text_splitter.create_documents(content, metadatas=metadata)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "4264bc1b-23b1-4547-a7f0-670944c3e605", | ||
"metadata": {}, | ||
"source": [ | ||
"## Index Documents into Elasticsearch using ELSER\n", | ||
"\n", | ||
"Before indexing, make sure you have [downloaded and deployed the ELSER model](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-elser.html#download-deploy-elser) in your deployment and that it is running on an ML node." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "eb1db78e-e40a-4a5c-9d15-75ee2a1d0994", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"es = ElasticsearchStore.from_documents(\n", | ||
"    docs,\n", | ||
"    es_cloud_id=ELASTIC_CLOUD_ID,\n", | ||
"    es_api_key=ELASTIC_API_KEY,\n", | ||
"    index_name=elastic_index_name,\n", | ||
"    strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(),\n", | ||
")\n", | ||
"\n", | ||
"es" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "02b1ead9-c442-40e9-ba81-d4d286ea878b", | ||
"metadata": {}, | ||
"source": [ | ||
"## Hugging Face login" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d2f651e4-e760-4b59-a8a3-57c58dfc229f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from huggingface_hub import notebook_login\n", | ||
"notebook_login()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "7454f551-71a9-4310-bb2a-3fe0e683daab", | ||
"metadata": {}, | ||
"source": [ | ||
"## Load the model and tokenizer (`google/gemma-2b-it`)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "3e1d98eb-0f4e-4c41-a851-125b75502963", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model = AutoModelForCausalLM.from_pretrained(\"google/gemma-2b-it\")\n", | ||
"tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2b-it\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "11a12596-2ac1-4101-b189-d21d53d33b04", | ||
"metadata": {}, | ||
"source": [ | ||
"## Create a `text-generation` pipeline and initialize with LLM" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "623e74fb-5707-44f7-9dd8-d9499f7ab61e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"pipe = pipeline(\n", | ||
"    \"text-generation\",\n", | ||
"    model=model,\n", | ||
"    tokenizer=tokenizer,\n", | ||
"    max_new_tokens=1024,\n", | ||
")\n", | ||
"\n", | ||
"llm = HuggingFacePipeline(\n", | ||
"    pipeline=pipe,\n", | ||
"    model_kwargs={\"temperature\": 0.7},\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "49ce0e72-e419-4310-85e9-09077d6c40b2", | ||
"metadata": {}, | ||
"source": [ | ||
"## Format Docs" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b3c07a75-9220-4a82-a92e-3fc2727ad3ba", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def format_docs(docs):\n", | ||
"    print(docs)\n", | ||
"    return \"\\n\\n\".join(doc.page_content for doc in docs)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f6266222-6ec3-495a-8f14-460549bab89d", | ||
"metadata": {}, | ||
"source": [ | ||
"## Create a chain using Prompt template" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ec203d1a-104b-4583-9ba1-a6b4b0354367", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"retriever = es.as_retriever(search_kwargs={\"k\": 10})\n", | ||
"\n", | ||
"template = \"\"\"Answer the question based only on the following context:\\n\n", | ||
"\n", | ||
"{context}\n", | ||
"\n", | ||
"Question: {question}\n", | ||
"\"\"\"\n", | ||
"prompt = ChatPromptTemplate.from_template(template)\n", | ||
"\n", | ||
"\n", | ||
"chain = (\n", | ||
"    {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n", | ||
"    | prompt\n", | ||
"    | llm\n", | ||
"    | StrOutputParser()\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8ae892dd-7442-4d4d-a804-1d717266e596", | ||
"metadata": {}, | ||
"source": [ | ||
"## Ask a question" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ba312f17-44ae-423d-89a0-ea01eccd85b5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"q = input(\"Question: \")\n", | ||
"chain.invoke(q)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
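
The notebook's chain joins the retrieved passages with `format_docs` and fills them into the prompt template before calling Gemma. The sketch below reproduces only that step in plain Python so it can be checked without Elasticsearch or a model. `Passage`, `build_prompt`, and the sample passages are hypothetical stand-ins invented for illustration (the notebook itself uses LangChain `Document` objects and `ChatPromptTemplate`), and the template is a close approximation of the one in the notebook, not a byte-for-byte copy.

```python
from dataclasses import dataclass, field


@dataclass
class Passage:
    """Hypothetical stand-in for a LangChain Document (only the fields used here)."""
    page_content: str
    metadata: dict = field(default_factory=dict)


# Approximation of the notebook's prompt template.
TEMPLATE = """Answer the question based only on the following context:

{context}

Question: {question}
"""


def format_docs(docs):
    # Same joining rule as the notebook's format_docs, without the debug print.
    return "\n\n".join(doc.page_content for doc in docs)


def build_prompt(docs, question):
    # Fill the template with the joined passages and the user's question,
    # which is roughly what the chain does before the text reaches the LLM.
    return TEMPLATE.format(context=format_docs(docs), question=question)


# Made-up passages standing in for the top-k results from the retriever.
sample_docs = [
    Passage("Employees may work remotely up to three days a week."),
    Passage("Remote work requests go through the team manager."),
]

prompt_text = build_prompt(sample_docs, "What is the remote work policy?")
print(prompt_text)
```

Printing the assembled prompt like this is a cheap way to see how much context the model receives, which matters here because the retriever returns 10 passages per question.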