Skip to content

Commit

Permalink
Merge pull request #89 from lhyscau/feature/openai_embedding
Browse files Browse the repository at this point in the history
完全使用 openai api 的版本实现
  • Loading branch information
yanqiangmiffy authored Jan 19, 2025
2 parents 2aa980d + d55f202 commit 6d6968a
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 20 deletions.
25 changes: 15 additions & 10 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,31 @@
import gradio as gr
import loguru

from trustrag.applications.rag import RagApplication, ApplicationConfig
# from trustrag.applications.rag import RagApplication, ApplicationConfig
from trustrag.applications.rag_openai import RagApplication, ApplicationConfig
from trustrag.modules.reranker.bge_reranker import BgeRerankerConfig
from trustrag.modules.retrieval.dense_retriever import DenseRetrieverConfig

# 修改成自己的配置!!!
app_config = ApplicationConfig()
app_config.docs_path = "/data/users/searchgpt/yq/trustrag/data/docs/"
app_config.llm_model_path = "/data/users/searchgpt/pretrained_models/glm-4-9b-chat"
app_config.docs_path = "data/docs"
app_config.llm_model_path = "gpt-4o"
app_config.base_url = "https://api.openai-up.com/v1"
app_config.api_key = "llm_api_key"

retriever_config = DenseRetrieverConfig(
model_name_or_path="/data/users/searchgpt/pretrained_models/bge-large-zh-v1.5",
dim=1024,
index_path='/data/users/searchgpt/yq/TrustRAG/examples/retrievers/dense_cache'
)
rerank_config = BgeRerankerConfig(
model_name_or_path="/data/users/searchgpt/pretrained_models/bge-reranker-large"
model_name_or_path="text-embedding-ada-002",
dim=1536, #openai_dim
index_path='examples/retrievers/dense_cache',
base_url = "https://api.openai-up.com/v1",
api_key = "embedding_api_key"
)
# rerank_config = BgeRerankerConfig(
# model_name_or_path="/data/users/searchgpt/pretrained_models/bge-reranker-large"
# )

app_config.retriever_config = retriever_config
app_config.rerank_config = rerank_config
# app_config.rerank_config = rerank_config
application = RagApplication(app_config)
application.init_vector_store()

Expand Down
13 changes: 8 additions & 5 deletions trustrag/applications/rag_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,26 @@
"""
import os
from trustrag.modules.document.common_parser import CommonParser
from trustrag.modules.generator.chat import DeepSeekChat
from trustrag.modules.reranker.bge_reranker import BgeReranker
from trustrag.modules.generator.chat import DeepSeekChat, GPT_4o_up
# from trustrag.modules.reranker.bge_reranker import BgeReranker
from trustrag.modules.retrieval.dense_retriever import DenseRetriever
from trustrag.modules.document.chunk import TextChunker

class ApplicationConfig():
def __init__(self):
self.retriever_config = None
self.rerank_config = None
self.api_key = None
self.base_url = None


class RagApplication():
def __init__(self, config):
self.config = config
self.parser = CommonParser()
self.retriever = DenseRetriever(self.config.retriever_config)
self.reranker = BgeReranker(self.config.rerank_config)
self.llm = DeepSeekChat(key=self.config.your_key)
# self.reranker = BgeReranker(self.config.rerank_config)
self.llm = GPT_4o_up(key=self.config.api_key)
self.tc=TextChunker()
self.rag_prompt="""请结合参考的上下文内容回答用户问题,如果上下文不能支撑用户问题,那么回答不知道或者我无法根据参考信息回答。
问题: {question}
Expand Down Expand Up @@ -70,7 +72,7 @@ def add_document(self, file_path):

def chat(self, question: str = '', top_k: int = 5):
contents = self.retriever.retrieve(query=question, top_k=top_k)
contents = self.reranker.rerank(query=question, documents=[content['text'] for content in contents])
# contents = self.reranker.rerank(query=question, documents=[content['text'] for content in contents])
content = '\n'.join([content['text'] for content in contents])

system_prompt = "你是一个人工智能助手."
Expand All @@ -85,4 +87,5 @@ def chat(self, question: str = '', top_k: int = 5):

# 调用 chat 方法进行对话
result = self.llm.chat(system=system_prompt, history=history, gen_conf=gen_conf)
result = result[0] #result[1]是 total_token,在调用 在线 api 时需要这样处理
return result, history, contents
5 changes: 4 additions & 1 deletion trustrag/modules/generator/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepse
if not base_url: base_url="https://api.deepseek.com"
super().__init__(key, model_name, base_url)


class GPT_4o_up(Base):
def __init__(self, key, model_name="gpt-4o", base_url="https://api.openai-up.com/v1"):
if not base_url: base_url="https://api.openai-up.com/v1"
super().__init__(key, model_name, base_url)

class GptTurbo(Base):
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
Expand Down
29 changes: 25 additions & 4 deletions trustrag/modules/retrieval/dense_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import gc
import os
from typing import List,Dict,Union

from openai import OpenAI
import faiss
import numpy as np
from FlagEmbedding import FlagModel
Expand All @@ -37,12 +37,16 @@ def __init__(
model_name_or_path='sentence-transformers/all-mpnet-base-v2',
dim=768,
index_path=None,
batch_size=32
batch_size=32,
api_key=None,
base_url=None
):
self.model_name = model_name_or_path
self.dim = dim
self.index_path = index_path
self.batch_size = batch_size
self.api_key = api_key
self.base_url = base_url

def validate(self):
"""Validate Dense configuration parameters."""
Expand Down Expand Up @@ -72,7 +76,11 @@ class DenseRetriever(BaseRetriever):

def __init__(self, config):
self.config = config
self.model = FlagModel(config.model_name)
# self.model = FlagModel(config.model_name)
self.client = OpenAI(
base_url=config.base_url, # 替换为你的 API 地址
api_key=config.api_key # 替换为你的 API 密钥
)
self.index = faiss.IndexFlatIP(config.dim)
self.dim = config.dim
self.embeddings = []
Expand Down Expand Up @@ -134,7 +142,20 @@ def get_embedding(self, sentences: List[str]) -> np.ndarray:
np.ndarray: A numpy array of embeddings.
"""
# Using configured batch_size
return self.model.encode(sentences=sentences, batch_size=self.batch_size)
# return self.model.encode(sentences=sentences, batch_size=self.batch_size)

#防止chunk为空字符串
sentences = [sentence if sentence else "This is a none string." for sentence in sentences]

response = self.client.embeddings.create(
input=sentences,
model=self.config.model_name
)
embedding = [np.array(item.embedding) for item in response.data]
# 提取嵌入向量
embedding = np.array(embedding)
return embedding

def add_texts(self, texts: List[str]):
"""
Add multiple texts to the index.
Expand Down

0 comments on commit 6d6968a

Please sign in to comment.