Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add RAG for general campus queries #4

Merged
merged 10 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
.env
__pycache__/
.langgraph_api
chroma
Chaitanya-Keyal marked this conversation as resolved.
Show resolved Hide resolved
.vscode/
4,356 changes: 4,236 additions & 120 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ readme = "README.md"
package-mode = false

[tool.poetry.dependencies]
python = "^3.11"
python = ">=3.11,<3.13"
python-dotenv = "^1.0.1"
langgraph = "^0.2.60"
langsmith = "^0.2.7"
Expand All @@ -17,6 +17,11 @@ pre-commit = "^4.0.1"
psycopg2-binary = "^2.9.10"
langchain-core = "^0.3.29"
langchain = "^0.3.14"
chromadb = "^0.6.2"
langchain-community = "^0.3.14"
sentence-transformers = "^3.3.1"
unstructured = {extras = ["md", "pdf"], version = "^0.16.12"}
langchain-huggingface = "^0.1.2"


[build-system]
Expand Down
16 changes: 10 additions & 6 deletions src/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@

from dotenv import load_dotenv
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain.prompts import ChatPromptTemplate
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_groq import ChatGroq

from src.campus_rag.rag_chain_components import get_retriever, output_parser
from src.tools.memory_tool import tool_modify_memory

load_dotenv()
Expand Down Expand Up @@ -63,19 +66,20 @@ def intent_classifier(self, query: str, chat_history: str) -> str:
return result

def general_campus_query(self, query: str, chat_history: str) -> str:
prompt = self._get_prompt(
PROMPT_TEMPLATE = self._get_prompt(
"GENERAL_CAMPUS_QUERY_AGENT",
f"<query>{query}</query>\n\n<history>{chat_history}</history>",
)

chain = prompt | self.llm
vectorstore_retriever = get_retriever()

result = chain.invoke(
{
"input": query,
}
setup_and_retrieval = RunnableParallel(
{"context": vectorstore_retriever, "question": RunnablePassthrough()}
)

chain = setup_and_retrieval | PROMPT_TEMPLATE | self.llm | output_parser

result = chain.invoke(query)
return result

def course_query(self, query: str, chat_history: str) -> str:
Expand Down
90 changes: 90 additions & 0 deletions src/campus_rag/create_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import os
import shutil

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
DirectoryLoader,
SitemapLoader,
WebBaseLoader,
)
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings

model_id = "sentence-transformers/all-MiniLM-L6-v2"
model_kwargs = {"device": "cpu"}
embeddings = HuggingFaceEmbeddings(model_name=model_id, model_kwargs=model_kwargs)

CHROMA_PATH = "chroma"
DATA_PATH = "data"


def load_documents(extension="txt") -> list[Document]:
print(f"Loading {extension} documents...")
loader = DirectoryLoader(DATA_PATH, glob=f"**/*.{extension}")
documents = loader.load()
return documents


def load_web_documents() -> list[Document]:
# Sitemap Loader is taking forever and failing because of TooManyRedirects
loader = SitemapLoader(
web_path="https://www.bits-pilani.ac.in/campus-sitemap.xml",
filter_urls=["https://www.bits-pilani.ac.in/hyderabad"],
continue_on_failure=True,
) # filter for only hyderabad related pages

# loader = WebBaseLoader("https://www.bits-pilani.ac.in/hyderabad/")
documents = loader.load()
return documents


def split_text(documents: list[Document]):
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=300,
length_function=len,
add_start_index=True,
)
chunks = text_splitter.split_documents(documents)
print(f"Split {len(documents)} documents into {len(chunks)} chunks.")

return chunks


def save_to_chroma(chunks: list[Document]):
if os.path.exists(CHROMA_PATH):
shutil.rmtree(CHROMA_PATH)

db = Chroma.from_documents(chunks, embeddings, persist_directory=CHROMA_PATH)
db.persist()
print(f"Saved {len(chunks)} chunks to {CHROMA_PATH}.")


def generate_data_store():
print(f"Loading documents in {DATA_PATH}...")

documents = load_documents("md")
chunks = split_text(documents)

documents = load_documents("pdf")
chunks.extend(split_text(documents))

documents = load_documents("txt")
chunks.extend(split_text(documents))

# print("Loading web documents...")
# web_documents = load_web_documents()
# web_chunks = split_text(web_documents)

# chunks.extend(web_chunks)
print("Saving to Chroma...")
save_to_chroma(chunks)


def main():
generate_data_store()


if __name__ == "__main__":
main()
Binary file not shown.
Loading
Loading