Skip to content

Commit

Permalink
fix bug (#29)
Browse files Browse the repository at this point in the history
1. 0.2.0 相关补充功能
2. bisheng-langchain 的proxy模型bug
  • Loading branch information
yaojin3616 authored Sep 14, 2023
2 parents 075199e + 1e50d9c commit 51a8969
Show file tree
Hide file tree
Showing 18 changed files with 103 additions and 94 deletions.
5 changes: 5 additions & 0 deletions docker/bisheng/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ database_url:
redis_url:
"redis:6379"

# admin 用户配置
admin:
user_name: "admin"
password: "1234"

# 为知识库的embedding进行模型撇脂
knowledges:
embeddings:
Expand Down
6 changes: 6 additions & 0 deletions docker/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ services:

backend:
image: dataelement/bisheng-backend:latest
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:7860/health"]
interval: 1m30s
timeout: 30s
retries: 3
start_period: 30s
volumes:
- ${DOCKER_VOLUME_DIRECTORY:-.}/bisheng/config/config.yaml:/app/bisheng/config.yaml
- ${DOCKER_VOLUME_DIRECTORY:-.}/bisheng/data/:/app/data/
Expand Down
40 changes: 12 additions & 28 deletions src/backend/bisheng/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@


def has_api_terms(word: str):
return 'api' in word and (
'key' in word or ('token' in word and 'tokens' not in word)
)
return 'api' in word and ('key' in word or ('token' in word and 'tokens' not in word))


def remove_api_keys(flow: dict):
Expand All @@ -18,11 +16,7 @@ def remove_api_keys(flow: dict):
node_data = node.get('data').get('node')
template = node_data.get('template')
for value in template.values():
if (
isinstance(value, dict)
and has_api_terms(value['name'])
and value.get('password')
):
if (isinstance(value, dict) and has_api_terms(value['name']) and value.get('password')):
value['value'] = None

return flow
Expand All @@ -32,7 +26,9 @@ def build_input_keys_response(langchain_object, artifacts):
"""Build the input keys response."""

input_keys_response = {
'input_keys': {key: '' for key in langchain_object.input_keys},
'input_keys': {
key: '' for key in langchain_object.input_keys
},
'memory_keys': [],
'handle_keys': artifacts.get('handle_keys', []),
}
Expand All @@ -43,9 +39,7 @@ def build_input_keys_response(langchain_object, artifacts):
input_keys_response['input_keys'][key] = value
# If the object has memory, that memory will have a memory_variables attribute
# memory variables should be removed from the input keys
if hasattr(langchain_object, 'memory') and hasattr(
langchain_object.memory, 'memory_variables'
):
if hasattr(langchain_object, 'memory') and hasattr(langchain_object.memory, 'memory_variables'):
# Remove memory variables from input keys
input_keys_response['input_keys'] = {
key: value
Expand All @@ -55,9 +49,7 @@ def build_input_keys_response(langchain_object, artifacts):
# Add memory variables to memory_keys
input_keys_response['memory_keys'] = langchain_object.memory.memory_variables

if hasattr(langchain_object, 'prompt') and hasattr(
langchain_object.prompt, 'template'
):
if hasattr(langchain_object, 'prompt') and hasattr(langchain_object.prompt, 'template'):
input_keys_response['template'] = langchain_object.prompt.template

return input_keys_response
Expand All @@ -84,9 +76,7 @@ def build_flow(graph_data: dict, artifacts, process_file=False, flow_id=None, ch
# 如果存在文件,当前不操作文件,避免重复操作
if not process_file:
template_dict = {
key: value
for key, value in vertex.data['node']['template'].items()
if isinstance(value, dict)
key: value for key, value in vertex.data['node']['template'].items() if isinstance(value, dict)
}
for key, value in template_dict.items():
if value.get('type') == 'file':
Expand All @@ -103,9 +93,7 @@ def build_flow(graph_data: dict, artifacts, process_file=False, flow_id=None, ch
vertex.build()
params = vertex._built_object_repr()
valid = True
logger.debug(
f"Building node {str(params)[:50]}{'...' if len(str(params)) > 50 else ''}"
)
logger.debug(f"Building node {str(params)[:50]}{'...' if len(str(params)) > 50 else ''}")
if vertex.artifacts:
# The artifacts will be prompt variables
# passed to build_input_keys_response
Expand Down Expand Up @@ -139,16 +127,14 @@ def build_flow_no_yield(graph_data: dict, artifacts, process_file=False, flow_id
graph = Graph.from_payload(graph_data)
except Exception as exc:
logger.exception(exc)
return
raise exc

for i, vertex in enumerate(graph.generator_build(), 1):
try:
# 如果存在文件,当前不操作文件,避免重复操作
if not process_file:
template_dict = {
key: value
for key, value in vertex.data['node']['template'].items()
if isinstance(value, dict)
key: value for key, value in vertex.data['node']['template'].items() if isinstance(value, dict)
}
for key, value in template_dict.items():
if value.get('type') == 'file':
Expand All @@ -164,9 +150,7 @@ def build_flow_no_yield(graph_data: dict, artifacts, process_file=False, flow_id

vertex.build()
params = vertex._built_object_repr()
logger.debug(
f"Building node {str(params)[:50]}{'...' if len(str(params)) > 50 else ''}"
)
logger.debug(f"Building node {str(params)[:50]}{'...' if len(str(params)) > 50 else ''}")
if vertex.artifacts:
# The artifacts will be prompt variables
# passed to build_input_keys_response
Expand Down
1 change: 0 additions & 1 deletion src/backend/bisheng/api/v1/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from bisheng.api.v1.schemas import FlowListCreate, FlowListRead
from bisheng.database.base import get_session
from bisheng.database.models.flow import Flow, FlowCreate, FlowRead, FlowReadWithStyle, FlowUpdate
from bisheng.database.models.template import Template
from bisheng.database.models.user import User
from bisheng.settings import settings
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
Expand Down
6 changes: 2 additions & 4 deletions src/backend/bisheng/api/v1/knowledge.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import asyncio
from ctypes import Union
import json
import time
from typing import List, Optional
from uuid import uuid4

from sqlalchemy import func

from bisheng.api.v1.schemas import UploadFileResponse
from bisheng.cache.utils import save_uploaded_file
from bisheng.database.base import get_session
Expand All @@ -25,6 +22,7 @@
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Milvus
from langchain.vectorstores.base import VectorStore
from sqlalchemy import func
from sqlmodel import Session, select

# build router
Expand Down Expand Up @@ -167,7 +165,7 @@ def get_filelist(*, session: Session = Depends(get_session), knowledge_id: int,
files = session.exec(
select(KnowledgeFile).where(KnowledgeFile.knowledge_id == knowledge_id).order_by(
KnowledgeFile.update_time.desc()).offset(page_size * (page_num - 1)).limit(page_size)).all()
return {"data": [jsonable_encoder(knowledgefile) for knowledgefile in files], "total": total_count}
return {'data': [jsonable_encoder(knowledgefile) for knowledgefile in files], 'total': total_count}


@router.delete('/{knowledge_id}', status_code=200)
Expand Down
2 changes: 2 additions & 0 deletions src/backend/bisheng/api/v1/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ async def update_model(endpoint: str, server: str):
elif reason != 'unloaded':
db_model.status = '异常'
db_model.remark = error_translate(reason)
if not db_model.status:
db_model.status = '未上线'
logger.debug(
f'update_status={model_name} rt_status={status} db_status={origin_status} now_status={db_model.status}')
if not db_model.config:
Expand Down
1 change: 0 additions & 1 deletion src/backend/bisheng/api/v1/user.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import hashlib
import json
from typing import List

from sqlalchemy import func

Expand Down
6 changes: 6 additions & 0 deletions src/backend/bisheng/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ database_url:
redis_url:
"redis:6379"

# admin 用户配置
admin:
user_name: "admin"
password: "1234"


# 为知识库的embedding进行模型撇脂
knowledges:
embeddings:
Expand Down
15 changes: 14 additions & 1 deletion src/backend/bisheng/database/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import hashlib

from bisheng.database.models.user import User
from bisheng.settings import settings
from bisheng.utils.logger import logger
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel import Session, SQLModel, create_engine, select

if settings.database_url and settings.database_url.startswith('sqlite'):
connect_args = {'check_same_thread': False}
Expand Down Expand Up @@ -31,6 +34,16 @@ def create_db_and_tables():
else:
logger.debug('Database and tables created successfully')

# 写入默认数据
with Session(engine) as session:
user = session.exec(select(User).limit(1)).all()
if not user:
md5 = hashlib.md5()
md5.update(settings.admin.get('password').encode('utf-8'))
user = User(user_name=settings.admin.get('user_name'), password=md5.hexdigest(), role='admin')
session.add(user)
session.commit()


def get_session():
with Session(engine) as session:
Expand Down
2 changes: 1 addition & 1 deletion src/backend/bisheng/database/models/model_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class ModelDeployBase(SQLModelSerializable):
model: str = Field(index=False)
config: Optional[str] = Field(index=False, sa_column=(String(length=1024)))
status: Optional[str] = Field(index=False)
remark: Optional[str] = Field(index=False)
remark: Optional[str] = Field(index=False, sa_column=(String(length=4096)))

create_time: Optional[datetime] = Field(
sa_column=Column(DateTime, nullable=False, index=True, server_default=text('CURRENT_TIMESTAMP')))
Expand Down
22 changes: 20 additions & 2 deletions src/backend/bisheng/interface/initialize/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@
import os
from typing import Any, Callable, Dict, Type

from bisheng.database.base import get_session
from bisheng.database.models.knowledge import Knowledge
from bisheng.settings import settings
from bisheng_langchain.embeddings.host_embedding import HostEmbeddings
from bisheng_langchain.vectorstores import ElasticKeywordsSearch
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import (FAISS, Chroma, Milvus, MongoDBAtlasVectorSearch, Pinecone,
Qdrant, SupabaseVectorStore, Weaviate)
from sqlmodel import select


def docs_in_params(params: dict) -> bool:
Expand Down Expand Up @@ -203,8 +208,21 @@ def initialize_qdrant(class_object: Type[Qdrant], params: dict):

def initial_milvus(class_object: Type[Milvus], params: dict):
if 'connection_args' not in params:
connection_args = settings.knowledges.get('vectorstores').get('Milvus')
params['connection_args'] = connection_args
params['connection_args'] = settings.knowledges.get('vectorstores').get('Milvus')
if 'embedding' not in params:
# 匹配知识库的embedding
col = params['collection_name']
with get_session() as session:
knowledge = session.exec(select(Knowledge).where(Knowledge.collection_name == col)).first()
if not knowledge:
raise Exception(f'不能找到知识库collection={col}')
model_param = settings.knowledges.get('embeddings').get(knowledge.model)
if Knowledge.model == 'text-embedding-ada-002':
embedding = OpenAIEmbeddings(**model_param)
else:
embedding = HostEmbeddings(**model_param)
params['embedding'] = embedding

elif isinstance(params.get('connection_args'), str):
print(f"milvus before params={params} type={type(params['connection_args'])}")
params['connection_args'] = json.loads(params.pop('connection_args'))
Expand Down
2 changes: 2 additions & 0 deletions src/backend/bisheng/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class Settings(BaseSettings):
dev: bool = False
database_url: Optional[str] = None
redis_url: Optional[str] = None
admin: dict = {}
cache: str = 'InMemoryCache'
remove_api_keys: bool = False

Expand Down Expand Up @@ -71,6 +72,7 @@ def update_from_yaml(self, file_path: str, dev: bool = False):
self.retrievers = new_settings.retrievers or {}
self.output_parsers = new_settings.output_parsers or {}
self.input_output = new_settings.input_output or {}
self.admin = new_settings.admin or {}
self.dev = dev

def update_settings(self, **kwargs):
Expand Down
Loading

0 comments on commit 51a8969

Please sign in to comment.