Skip to content

Commit

Permalink
perf: update ingestion, add todo for retriever
Browse files Browse the repository at this point in the history
  • Loading branch information
datvodinh committed May 11, 2024
1 parent d7d71ec commit ac454df
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 4 deletions.
3 changes: 2 additions & 1 deletion rag_chatbot/core/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Public API of the engine subpackage: re-export the chat engine and the
# retriever so callers can do `from rag_chatbot.core.engine import ...`.
from .engine import LocalChatEngine
from .retriever import LocalRetriever

__all__ = [
    "LocalChatEngine",
    "LocalRetriever",
]
3 changes: 2 additions & 1 deletion rag_chatbot/core/engine/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

class TwoStageRetriever:
    """Placeholder for a two-stage retrieval pipeline.

    NOTE(review): intentionally an empty stub — the commit that introduced
    it says "add todo for retriever"; the implementation is pending.
    """

    def __init__(self) -> None:
        """No-op constructor; state will be added when the retriever is implemented."""
        # TODO
        pass


Expand All @@ -43,7 +44,7 @@ def _get_two_stage_retriever(
similarity_top_k=self._setting.retriever.similarity_top_k,
embed_model=Settings.embed_model,
verbose=True
)
) # TODO

def _get_fusion_retriever(
self,
Expand Down
9 changes: 7 additions & 2 deletions rag_chatbot/core/ingestion/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from llama_index.core.schema import BaseNode
from llama_index.core.node_parser import SentenceSplitter
from dotenv import load_dotenv
from typing import List
from typing import Any, List
from tqdm import tqdm
from ...setting import RAGSettings

Expand All @@ -19,6 +19,8 @@ def __init__(self, setting: RAGSettings | None = None) -> None:
def store_nodes(
self,
input_files: list[str],
embed_nodes: bool = True,
embed_model: Any | None = None
) -> List[BaseNode]:
splitter = SentenceSplitter.from_defaults(
chunk_size=self._setting.ingestion.chunk_size,
Expand All @@ -32,6 +34,8 @@ def store_nodes(
]
return_nodes = []
self._ingested_file = []
if embed_nodes:
Settings.embed_model = embed_model or Settings.embed_model
for input_file in tqdm(input_files, desc="Ingesting data"):
file_name = input_file.strip().split('/')[-1]
self._ingested_file.append(file_name)
Expand All @@ -50,7 +54,8 @@ def store_nodes(
doc.excluded_llm_metadata_keys = excluded_keys

nodes = splitter(document)
nodes = Settings.embed_model(nodes)
if embed_nodes:
nodes = Settings.embed_model(nodes)
self._node_store[file_name] = nodes
return_nodes.extend(nodes)

Expand Down
3 changes: 3 additions & 0 deletions rag_chatbot/setting/setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@


class OllamaSettings(BaseModel):
llm: str = Field(
default="llama3:8b-instruct-q8_0", description="LLM model"
)
keep_alive: str = Field(
default="1h", description="Keep alive time for the server"
)
Expand Down

0 comments on commit ac454df

Please sign in to comment.