notebook for RAG using gemma, huggingface and elastic
ashishtiwari1993 committed Mar 5, 2024
1 parent 41c0a05 commit df2fdd1
Showing 1 changed file with 359 additions and 0 deletions.
359 changes: 359 additions & 0 deletions notebooks/integrations/gemma/rag-gemma-huggingface-elastic.ipynb
@@ -0,0 +1,359 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "faa09879-9128-4864-8bb5-945ef9b8e84c",
"metadata": {},
"source": [
"# RAG using Google's Gemma, Hugging Face and Elastic"
]
},
{
"cell_type": "markdown",
"id": "d047438b-6f18-47ed-aac9-12c741cefd06",
"metadata": {},
"source": [
"In this notebook, we build a RAG system around [Google's Gemma](https://ai.google.dev/gemma) model. We generate sparse vectors with [Elastic's ELSER](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-elser.html) model and store them in Elasticsearch, run semantic retrieval against that index, and pass the top search results to Gemma as its context. We use the Hugging Face [transformers](https://huggingface.co/google/gemma-2b-it) library to load Gemma in a local environment."
]
},
{
"cell_type": "markdown",
"id": "1bd3acec-d490-4139-bab1-b874e1e7db8d",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"id": "ef406b8a-03fb-49c5-baed-18e03bcd36d9",
"metadata": {},
"source": [
"**Elastic Credentials** - Create an [Elastic Cloud deployment](https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud) to get your Elastic credentials (`ELASTIC_CLOUD_ID`, `ELASTIC_API_KEY`).\n",
"\n",
"**Hugging Face Token** - To get started with the [Gemma](https://huggingface.co/google/gemma-2b-it) model, you must accept the model's terms on Hugging Face and generate an [access token](https://huggingface.co/docs/hub/en/security-tokens) with the `write` role.\n",
"\n",
"**Gemma Model** - We're going to use [gemma-2b-it](https://huggingface.co/google/gemma-2b-it). Google has released four open Gemma models, and you can use any of the others instead: [gemma-2b](https://huggingface.co/google/gemma-2b), [gemma-7b](https://huggingface.co/google/gemma-7b), or [gemma-7b-it](https://huggingface.co/google/gemma-7b-it)."
]
},
{
"cell_type": "markdown",
"id": "ac91d7a3-1198-4b11-a9c5-50028abc861b",
"metadata": {},
"source": [
"## Install packages"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fda41538-444c-48d7-80a0-b34b2e158b82",
"metadata": {},
"outputs": [],
"source": [
"%pip install -q -U elasticsearch langchain transformers huggingface_hub"
]
},
{
"cell_type": "markdown",
"id": "15c2e924-e5a2-439b-8e98-f13a162db7fe",
"metadata": {},
"source": [
"## Import packages"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7219411b-fae6-4c2a-b170-796bc30ed073",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"from getpass import getpass\n",
"from urllib.request import urlopen\n",
"\n",
"from elasticsearch import Elasticsearch, helpers\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import ElasticsearchStore\n",
"from langchain.llms import HuggingFacePipeline\n",
"from langchain.chains import RetrievalQA\n",
"from langchain.prompts import ChatPromptTemplate\n",
"from langchain.schema.output_parser import StrOutputParser\n",
"from langchain.schema.runnable import RunnablePassthrough\n",
"from huggingface_hub import login\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline"
]
},
{
"cell_type": "markdown",
"id": "182a413f-e7fd-4361-8096-90736d3df33e",
"metadata": {},
"source": [
"## Get Credentials"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b184b3a5-0cc8-43f9-b15d-f5ccf48f574b",
"metadata": {},
"outputs": [],
"source": [
"ELASTIC_API_KEY = getpass(\"Elastic API Key :\")\n",
"ELASTIC_CLOUD_ID = getpass(\"Elastic Cloud ID :\")\n",
"elastic_index_name = \"gemma-rag\""
]
},
{
"cell_type": "markdown",
"id": "a2efbd81-70b9-409c-ab5f-796d538b42a1",
"metadata": {},
"source": [
"## Add documents"
]
},
{
"cell_type": "markdown",
"id": "161dfb9d-f11f-4de5-8489-6464ade0cdb2",
"metadata": {},
"source": [
"### Download the sample dataset and deserialize the documents"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49427546-7b37-48f4-a6fe-395736ea2d38",
"metadata": {},
"outputs": [],
"source": [
"url = \"https://raw.githubusercontent.com/ashishtiwari1993/langchain-elasticsearch-RAG/main/data.json\"\n",
"\n",
"response = urlopen(url)\n",
"\n",
"workplace_docs = json.loads(response.read())"
]
},
{
"cell_type": "markdown",
"id": "f3bf0104-8b31-4b39-ad21-b372fd1fa0db",
"metadata": {},
"source": [
"### Split Documents into Passages"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "79e55ed1-418e-48ed-b3e3-d28e10744eb5",
"metadata": {},
"outputs": [],
"source": [
"metadata = []\n",
"content = []\n",
"\n",
"for doc in workplace_docs:\n",
"    content.append(doc[\"content\"])\n",
"    metadata.append({\n",
"        \"name\": doc[\"name\"],\n",
"        \"summary\": doc[\"summary\"],\n",
"        \"rolePermissions\": doc[\"rolePermissions\"]\n",
"    })\n",
"\n",
"text_splitter = CharacterTextSplitter(chunk_size=50, chunk_overlap=0)\n",
"docs = text_splitter.create_documents(content, metadatas=metadata)"
]
},
{
"cell_type": "markdown",
"id": "4264bc1b-23b1-4547-a7f0-670944c3e605",
"metadata": {},
"source": [
"## Index Documents into Elasticsearch using ELSER\n",
"\n",
"Before indexing, make sure you have [downloaded and deployed the ELSER model](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-elser.html#download-deploy-elser) in your deployment and that it is running on an ML node."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eb1db78e-e40a-4a5c-9d15-75ee2a1d0994",
"metadata": {},
"outputs": [],
"source": [
"es = ElasticsearchStore.from_documents(\n",
"    docs,\n",
"    es_cloud_id=ELASTIC_CLOUD_ID,\n",
"    es_api_key=ELASTIC_API_KEY,\n",
"    index_name=elastic_index_name,\n",
"    strategy=ElasticsearchStore.SparseVectorRetrievalStrategy()\n",
")\n",
"\n",
"es"
]
},
{
"cell_type": "markdown",
"id": "02b1ead9-c442-40e9-ba81-d4d286ea878b",
"metadata": {},
"source": [
"## Hugging Face login"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d2f651e4-e760-4b59-a8a3-57c58dfc229f",
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import notebook_login\n",
"notebook_login()"
]
},
{
"cell_type": "markdown",
"id": "7454f551-71a9-4310-bb2a-3fe0e683daab",
"metadata": {},
"source": [
"## Load the model and tokenizer (`google/gemma-2b-it`)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3e1d98eb-0f4e-4c41-a851-125b75502963",
"metadata": {},
"outputs": [],
"source": [
"model = AutoModelForCausalLM.from_pretrained(\"google/gemma-2b-it\")\n",
"tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2b-it\")"
]
},
{
"cell_type": "markdown",
"id": "11a12596-2ac1-4101-b189-d21d53d33b04",
"metadata": {},
"source": [
"## Create a `text-generation` pipeline and wrap it as a LangChain LLM"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "623e74fb-5707-44f7-9dd8-d9499f7ab61e",
"metadata": {},
"outputs": [],
"source": [
"pipe = pipeline(\n",
"    \"text-generation\",\n",
"    model=model,\n",
"    tokenizer=tokenizer,\n",
"    max_new_tokens=1024,\n",
")\n",
"\n",
"llm = HuggingFacePipeline(\n",
"    pipeline=pipe,\n",
"    model_kwargs={\"temperature\": 0.7},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "49ce0e72-e419-4310-85e9-09077d6c40b2",
"metadata": {},
"source": [
"## Format Docs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b3c07a75-9220-4a82-a92e-3fc2727ad3ba",
"metadata": {},
"outputs": [],
"source": [
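"def format_docs(docs):\n",
"    # Print the retrieved documents for inspection, then join their text into one context string.\n",
"    print(docs)\n",
"    return \"\\n\\n\".join(doc.page_content for doc in docs)"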
]
},
{
"cell_type": "markdown",
"id": "f6266222-6ec3-495a-8f14-460549bab89d",
"metadata": {},
"source": [
"## Create a chain using a prompt template"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec203d1a-104b-4583-9ba1-a6b4b0354367",
"metadata": {},
"outputs": [],
"source": [
"retriever = es.as_retriever(search_kwargs={\"k\": 10})\n",
"\n",
"template = \"\"\"Answer the question based only on the following context:\\n\n",
"\n",
"{context}\n",
"\n",
"Question: {question}\n",
"\"\"\"\n",
"prompt = ChatPromptTemplate.from_template(template)\n",
"\n",
"\n",
"chain = (\n",
"    {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n",
"    | prompt\n",
"    | llm\n",
"    | StrOutputParser()\n",
")"
]
},
{
"cell_type": "markdown",
"id": "8ae892dd-7442-4d4d-a804-1d717266e596",
"metadata": {},
"source": [
"## Ask a question"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba312f17-44ae-423d-89a0-ea01eccd85b5",
"metadata": {},
"outputs": [],
"source": [
"q = input(\"Question: \")\n",
"chain.invoke(q)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
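
The notebook's chain retrieves passages, joins them with `format_docs`, and fills the prompt template before the text reaches Gemma. The sketch below reproduces only that prompt-assembly step in plain Python, so it can be run without Elasticsearch, LangChain, or a model. `Passage` and `build_prompt` are hypothetical names introduced for illustration (`Passage` stands in for a LangChain document and models only `page_content`), and the sample passages are invented text, not rows from the dataset.

```python
from dataclasses import dataclass


@dataclass
class Passage:
    """Hypothetical stand-in for a retrieved document; only page_content is modelled."""
    page_content: str


# Same wording as the notebook's template, written as a plain format string.
TEMPLATE = """Answer the question based only on the following context:

{context}

Question: {question}
"""


def format_docs(docs):
    # Join the passage texts with a blank line between them.
    return "\n\n".join(doc.page_content for doc in docs)


def build_prompt(docs, question):
    # Fill the template with the joined context and the user's question.
    return TEMPLATE.format(context=format_docs(docs), question=question)


if __name__ == "__main__":
    # Invented sample passages, used only to show the shape of the final prompt.
    passages = [
        Passage("Vacation requests are submitted through the HR portal."),
        Passage("Managers approve requests within five working days."),
    ]
    print(build_prompt(passages, "How do I request vacation?"))
```

Printing the assembled prompt this way is a quick check on how much context `k` retrieved passages add before the model's input limit becomes a concern.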
