-
Notifications
You must be signed in to change notification settings - Fork 0
/
streamlit_app.py
125 lines (108 loc) · 5.28 KB
/
streamlit_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import streamlit as st
from langchain_openai import ChatOpenAI
import os
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from zhipuai_llm import ZhipuAILLM
import sys
sys.path.append("../C3 搭建知识库") # 将父目录放入系统路径中
from zhipuai_embedding import ZhipuAIEmbeddings
from langchain.vectorstores.chroma import Chroma
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv()) # read local .env file
#export OPENAI_API_KEY=
#os.environ["OPENAI_API_BASE"] = 'https://api.chatgptid.net/v1'
zhipuai_api_key = "f63573f65d0be7d636627d22723cd8f2.fi65qV3t9Ez4iyxI"
def generate_response(input_text, openai_api_key):
# llm = ChatOpenAI(temperature=0.7, openai_api_key=openai_api_key)
llm = ZhipuAILLM(temperature=0.7, api_key=zhipuai_api_key)
output = llm.invoke(input_text)
output_parser = StrOutputParser()
output = output_parser.invoke(output)
#st.info(output)
return output
def get_vectordb():
# 定义 Embeddings
embedding = ZhipuAIEmbeddings()
# 向量数据库持久化路径
persist_directory = '../C3 搭建知识库/data_base/vector_db/chroma'
# 加载数据库
vectordb = Chroma(
persist_directory=persist_directory, # 允许我们将persist_directory目录保存到磁盘上
embedding_function=embedding
)
return vectordb
#带有历史记录的问答链
def get_chat_qa_chain(question:str,openai_api_key:str):
vectordb = get_vectordb()
# llm = ChatOpenAI(model_name = "gpt-3.5-turbo", temperature = 0,openai_api_key = openai_api_key)
llm = ZhipuAILLM(model = "glm-4", temperature = 0.1, api_key = zhipuai_api_key)
memory = ConversationBufferMemory(
memory_key="chat_history", # 与 prompt 的输入变量保持一致。
return_messages=True # 将以消息列表的形式返回聊天记录,而不是单个字符串
)
retriever=vectordb.as_retriever()
qa = ConversationalRetrievalChain.from_llm(
llm,
retriever=retriever,
memory=memory
)
result = qa({"question": question})
return result['answer']
#不带历史记录的问答链
def get_qa_chain(question:str,openai_api_key:str):
vectordb = get_vectordb()
# llm = ChatOpenAI(model_name = "gpt-3.5-turbo", temperature = 0,openai_api_key = openai_api_key)
llm = ZhipuAILLM(model = "glm-4", temperature = 0.1, api_key = zhipuai_api_key)
template = """使用以下上下文来回答最后的问题。如果你不知道答案,就说你不知道,不要试图编造答
案。最多使用三句话。尽量使答案简明扼要。总是在回答的最后说“谢谢你的提问!”。
{context}
问题: {question}
"""
QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context","question"],
template=template)
qa_chain = RetrievalQA.from_chain_type(llm,
retriever=vectordb.as_retriever(),
return_source_documents=True,
chain_type_kwargs={"prompt":QA_CHAIN_PROMPT})
result = qa_chain({"query": question})
return result["result"]
# Streamlit 应用程序界面
def main():
st.title('🦜🔗 动手学大模型应用开发')
openai_api_key = st.sidebar.text_input('OpenAI API Key', type='password')
# 添加一个选择按钮来选择不同的模型
#selected_method = st.sidebar.selectbox("选择模式", ["qa_chain", "chat_qa_chain", "None"])
selected_method = st.radio(
"你想选择哪种模式进行对话?",
["None", "qa_chain", "chat_qa_chain"],
captions = ["不使用检索问答的普通模式", "不带历史记录的检索问答模式", "带历史记录的检索问答模式"])
# 用于跟踪对话历史
if 'messages' not in st.session_state:
st.session_state.messages = []
messages = st.container(height=300)
if prompt := st.chat_input("Say something"):
# 将用户输入添加到对话历史中
st.session_state.messages.append({"role": "user", "text": prompt})
if selected_method == "None":
# 调用 respond 函数获取回答
answer = generate_response(prompt, openai_api_key)
elif selected_method == "qa_chain":
answer = get_qa_chain(prompt,openai_api_key)
elif selected_method == "chat_qa_chain":
answer = get_chat_qa_chain(prompt,openai_api_key)
# 检查回答是否为 None
if answer is not None:
# 将LLM的回答添加到对话历史中
st.session_state.messages.append({"role": "assistant", "text": answer})
# 显示整个对话历史
for message in st.session_state.messages:
if message["role"] == "user":
messages.chat_message("user").write(message["text"])
elif message["role"] == "assistant":
messages.chat_message("assistant").write(message["text"])
if __name__ == "__main__":
main()