diff --git a/examples/RAG.ipynb b/examples/RAG.ipynb new file mode 100644 index 00000000..077bf3e1 --- /dev/null +++ b/examples/RAG.ipynb @@ -0,0 +1,211 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from dotenv import load_dotenv; load_dotenv()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading builder script: 100%|██████████| 4.03k/4.03k [00:00<00:00, 14.0MB/s]\n", + "Downloading readme: 100%|██████████| 6.84k/6.84k [00:00<00:00, 4.53MB/s]\n", + "Downloading data: 100%|██████████| 81.4M/81.4M [00:09<00:00, 8.96MB/s]\n", + "Generating train split: 100%|██████████| 1600000/1600000 [00:24<00:00, 64850.86 examples/s]\n", + "Generating test split: 100%|██████████| 498/498 [00:00<00:00, 22408.02 examples/s]\n" + ] + } + ], + "source": [ + "# Get Data\n", + "#import csv\n", + "\n", + "\n", + "from datasets import load_dataset\n", + "ds = load_dataset(\"stanfordnlp/sentiment140\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "docs = ds['train'][:2000]['text']" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "# Get our encoder to encode \n", + "from sentence_transformers import SentenceTransformer \n", + "model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')\n", + "data_emb = model.encode(docs) #16 seconds on M1 Mac 8gb\n" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "# Now set up the DB to acceept the data\n", + "\n", + "import chromadb\n", + "\n", + "chroma_client = chromadb.Client()\n", + "collection = chroma_client.create_collection(name=\"SampleDB\")\n", + "\n", + "collection.add(\n", + " embeddings=data_emb.tolist(),\n", + " documents=docs,\n", + " ids=[str(idx) for idx in range(len(data_emb))]) # Doc ID's are requried\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "#prep the question in the encoding space\n", + "\n", + "question = 'What is the status of the chalkboard?'\n", + "question_emb = model.encode(question)\n", + "\n", + "results = collection.query(query_embeddings=question_emb.tolist(), n_results=10)\n", + "\n", + "context = ' '.join(results['documents'][0]) # Pulling out a lists of lists\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [], + "source": [ + "# Call the models to derterim the answer by response\n", + "\n", + "prompt = f'Given the following data, Please answer the question: \\n\\n ##question \\n {question}\\n\\n ##context \\n {context}'\n", + "\n", + "\n", + "from aimodels.client import MultiFMClient\n", + "client = MultiFMClient()\n", + "\n", + "messages = [\n", + " {\"role\": \"system\", \"content\": \"You are a helpful agent, who answers with brevity. \"},\n", + " {\"role\": \"user\", \"content\": prompt},\n", + "]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The status of the chalkboard is: USELESS (because there is no chalk).\n" + ] + } + ], + "source": [ + "#groq_llama3_8b = \"groq:llama3-8b-8192\" \n", + "response = client.chat.completions.create(model=\"groq:llama3-70b-8192\", messages=messages)\n", + "\n", + "print(response.choices[0].message.content)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"Damn... I don't have any chalk! MY CHALKBOARD IS USELESS \"" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results['documents'][0][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Based on the context provided, the status of the chalkboard is useless because the person doesn't have any chalk.\n" + ] + } + ], + "source": [ + "client = MultiFMClient() #have to reinaiatize client or for the API key. \n", + "\n", + "response = client.chat.completions.create(model=\"anthropic:claude-3-opus-20240229\", messages=messages)\n", + "print(response.choices[0].message.content)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}