Skip to content

Commit

Permalink
feat: Application import and export
Browse files Browse the repository at this point in the history
  • Loading branch information
shaohuzhang1 committed Dec 16, 2024
1 parent 390014f commit 53012c6
Show file tree
Hide file tree
Showing 8 changed files with 260 additions and 24 deletions.
109 changes: 105 additions & 4 deletions apps/application/serializers/application_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import hashlib
import json
import os
import pickle
import re
import uuid
from functools import reduce
Expand All @@ -19,10 +20,10 @@
from django.core import cache, validators
from django.core import signing
from django.db import transaction, models
from django.db.models import QuerySet, Q
from django.db.models import QuerySet
from django.http import HttpResponse
from django.template import Template, Context
from rest_framework import serializers
from rest_framework import serializers, status

from application.flow.workflow_manage import Flow
from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion
Expand All @@ -34,15 +35,17 @@
from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed
from common.field.common import UploadedImageField
from common.field.common import UploadedImageField, UploadedFileField
from common.models.db_model_manage import DBModelManage
from common.response import result
from common.util.common import valid_license, password_encrypt
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from dataset.models import DataSet, Document, Image
from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list
from embedding.models import SearchMode
from function_lib.serializers.function_lib_serializer import FunctionLibSerializer
from function_lib.models.function import FunctionLib, PermissionType
from function_lib.serializers.function_lib_serializer import FunctionLibSerializer, FunctionLibModelSerializer
from setting.models import AuthOperate
from setting.models.model_management import Model
from setting.models_provider import get_model_credential
Expand All @@ -54,6 +57,13 @@
chat_cache = cache.caches['chat_cache']


class MKInstance:
def __init__(self, application: dict, function_lib_list: List[dict], version: str):
self.application = application
self.function_lib_list = function_lib_list
self.version = version


class ModelDatasetAssociation(serializers.Serializer):
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
Expand Down Expand Up @@ -662,6 +672,72 @@ def edit(self, with_valid=True):
get_application_access_token(application_access_token.access_token, False)
return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data)}

class Import(serializers.Serializer):
file = UploadedFileField(required=True, error_messages=ErrMessage.image("文件"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))

@valid_license(model=Application, count=5,
message='社区版最多支持 5 个应用,如需拥有更多应用,请联系我们(https://fit2cloud.com/)。')
@transaction.atomic
def import_(self, with_valid=True):
if with_valid:
self.is_valid()
user_id = self.data.get('user_id')
mk_instance_bytes = self.data.get('file').read()
mk_instance = pickle.loads(mk_instance_bytes)
application = mk_instance.application
function_lib_list = mk_instance.function_lib_list
if len(function_lib_list) > 0:
function_lib_id_list = [function_lib.get('id') for function_lib in function_lib_list]
exits_function_lib_id_list = [str(function_lib.id) for function_lib in
QuerySet(FunctionLib).filter(id__in=function_lib_id_list)]
# 获取到需要插入的函数
function_lib_list = [function_lib for function_lib in function_lib_list if
not exits_function_lib_id_list.__contains__(function_lib.get('id'))]
application_model = self.to_application(application, user_id)
function_lib_model_list = [self.to_function_lib(f, user_id) for f in function_lib_list]
application_model.save()
QuerySet(FunctionLib).bulk_create(function_lib_model_list) if len(function_lib_model_list) > 0 else None
return True

@staticmethod
def to_application(application, user_id):
work_flow = application.get('work_flow')
for node in work_flow.get('nodes', []):
if node.get('type') == 'search-dataset-node':
node.get('properties', {}).get('node_data', {})['dataset_id_list'] = []
return Application(id=uuid.uuid1(), user_id=user_id, name=application.get('name'),
desc=application.get('desc'),
prologue=application.get('prologue'), dialogue_number=application.get('dialogue_number'),
dataset_setting=application.get('dataset_setting'),
model_params_setting=application.get('model_params_setting'),
tts_model_params_setting=application.get('tts_model_params_setting'),
problem_optimization=application.get('problem_optimization'),
icon=application.get('icon'),
work_flow=work_flow,
type=application.get('type'),
problem_optimization_prompt=application.get('problem_optimization_prompt'),
tts_model_enable=application.get('tts_model_enable'),
stt_model_enable=application.get('stt_model_enable'),
tts_type=application.get('tts_type'),
clean_time=application.get('clean_time'),
file_upload_enable=application.get('file_upload_enable'),
file_upload_setting=application.get('file_upload_setting'),
)

@staticmethod
def to_function_lib(function_lib, user_id):
"""
@param user_id: 用户id
@param function_lib: 函数库
@return:
"""
return FunctionLib(id=function_lib.get('id'), user_id=user_id, name=function_lib.get('name'),
code=function_lib.get('code'), input_field_list=function_lib.get('input_field_list'),
is_active=function_lib.get('is_active'),
permission_type=PermissionType.PRIVATE)

class Operate(serializers.Serializer):
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
Expand Down Expand Up @@ -708,6 +784,31 @@ def delete(self, with_valid=True):
QuerySet(Application).filter(id=self.data.get('application_id')).delete()
return True

def export(self, with_valid=True):
try:
if with_valid:
self.is_valid()
application_id = self.data.get('application_id')
application = QuerySet(Application).filter(id=application_id).first()
function_lib_id_list = [node.get('properties', {}).get('node_data', {}).get('function_lib_id') for node
in
application.work_flow.get('nodes', []) if
node.get('type') == 'function-lib-node']
function_lib_list = []
if len(function_lib_id_list) > 0:
function_lib_list = QuerySet(FunctionLib).filter(id__in=function_lib_id_list)
application_dict = ApplicationSerializerModel(application).data

mk_instance = MKInstance(application_dict,
[FunctionLibModelSerializer(function_lib).data for function_lib in
function_lib_list], 'v1')
application_pickle = pickle.dumps(mk_instance)
response = HttpResponse(content_type='text/plain', content=application_pickle)
response['Content-Disposition'] = f'attachment; filename="{application.name}.mk"'
return response
except Exception as e:
return result.error(str(e), response_status=status.HTTP_500_INTERNAL_SERVER_ERROR)

@transaction.atomic
def publish(self, instance, with_valid=True):
if with_valid:
Expand Down
21 changes: 21 additions & 0 deletions apps/application/swagger_api/application_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,27 @@ def get_request_params_api():
description='应用描述')
]

class Export(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),

]

class Import(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='file',
in_=openapi.IN_FORM,
type=openapi.TYPE_FILE,
required=True,
description='上传图片文件')
]

class Operate(ApiMixin):
@staticmethod
def get_request_params_api():
Expand Down
2 changes: 2 additions & 0 deletions apps/application/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
app_name = "application"
urlpatterns = [
path('application', views.Application.as_view(), name="application"),
path('application/import', views.Application.Import.as_view()),
path('application/profile', views.Application.Profile.as_view(), name='application/profile'),
path('application/embed', views.Application.Embed.as_view()),
path('application/authentication', views.Application.Authentication.as_view()),
path('application/<str:application_id>/publish', views.Application.Publish.as_view()),
path('application/<str:application_id>/edit_icon', views.Application.EditIcon.as_view()),
path('application/<str:application_id>/export', views.Application.Export.as_view()),
path('application/<str:application_id>/statistics/customer_count',
views.ApplicationStatistics.CustomerCount.as_view()),
path('application/<str:application_id>/statistics/customer_count_trend',
Expand Down
55 changes: 45 additions & 10 deletions apps/application/views/application_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from common.swagger_api.common_api import CommonApi
from common.util.common import query_params_to_single_dict
from dataset.serializers.dataset_serializers import DataSetSerializers
from setting.swagger_api.provide_api import ProvideApi

chat_cache = cache.caches['chat_cache']

Expand Down Expand Up @@ -158,6 +157,34 @@ def put(self, request: Request, application_id: str):
data={'application_id': application_id, 'user_id': request.user.id,
'image': request.FILES.get('file')}).edit(request.data))

class Import(APIView):
authentication_classes = [TokenAuth]
parser_classes = [MultiPartParser]

@action(methods="GET", detail=False)
@swagger_auto_schema(operation_summary="导入应用", operation_id="导入应用",
manual_parameters=ApplicationApi.Import.get_request_params_api(),
tags=["应用"]
)
@has_permissions(RoleConstants.ADMIN, RoleConstants.USER)
def post(self, request: Request):
return result.success(ApplicationSerializer.Import(
data={'user_id': request.user.id, 'file': request.FILES.get('file')}).import_())

class Export(APIView):
authentication_classes = [TokenAuth]

@action(methods="GET", detail=False)
@swagger_auto_schema(operation_summary="导出应用", operation_id="导出应用",
manual_parameters=ApplicationApi.Export.get_request_params_api(),
tags=["应用"]
)
@has_permissions(lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE,
dynamic_tag=keywords.get('application_id')))
def get(self, request: Request, application_id: str):
return ApplicationSerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id}).export()

class Embed(APIView):
@action(methods=["GET"], detail=False)
@swagger_auto_schema(operation_summary="获取嵌入js",
Expand Down Expand Up @@ -362,7 +389,8 @@ class AccessToken(APIView):
compare=CompareConstants.AND))
def put(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).edit(request.data))
ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).edit(
request.data))

@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取应用 AccessToken信息",
Expand All @@ -382,9 +410,10 @@ def get(self, request: Request, application_id: str):
class Authentication(APIView):
@action(methods=['OPTIONS'], detail=False)
def options(self, request, *args, **kwargs):
return HttpResponse(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true",
"Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}, )
return HttpResponse(
headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true",
"Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}, )

@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="应用认证",
Expand All @@ -404,6 +433,7 @@ def post(self, request: Request):
)

@action(methods=['POST'], detail=False)

@swagger_auto_schema(operation_summary="创建应用",
operation_id="创建应用",
request_body=ApplicationApi.Create.get_request_body_api(),
Expand Down Expand Up @@ -444,7 +474,8 @@ def get(self, request: Request, application_id: str):
"query_text": request.query_params.get("query_text"),
"top_number": request.query_params.get("top_number"),
'similarity': request.query_params.get('similarity'),
'search_mode': request.query_params.get('search_mode')}).hit_test(
'search_mode': request.query_params.get(
'search_mode')}).hit_test(
))

class Publish(APIView):
Expand Down Expand Up @@ -502,7 +533,8 @@ def delete(self, request: Request, application_id: str):
compare=CompareConstants.AND))
def put(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.Operate(data={'application_id': application_id, 'user_id': request.user.id}).edit(
ApplicationSerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id}).edit(
request.data))

@action(methods=['GET'], detail=False)
Expand All @@ -528,11 +560,14 @@ class ListApplicationDataSet(APIView):
@swagger_auto_schema(operation_summary="获取当前应用可使用的知识库",
operation_id="获取当前应用可使用的知识库",
manual_parameters=ApplicationApi.Operate.get_request_params_api(),
responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()),
responses=result.get_api_array_response(
DataSetSerializers.Query.get_response_body_api()),
tags=['应用'])
@has_permissions(ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))],
[lambda r, keywords: Permission(group=Group.APPLICATION,
operate=Operate.USE,
dynamic_tag=keywords.get(
'application_id'))],
compare=CompareConstants.AND))
def get(self, request: Request, application_id: str):
return result.success(ApplicationSerializer.Operate(
Expand Down
4 changes: 2 additions & 2 deletions apps/common/response/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,10 @@ def success(data, **kwargs):
return Result(data=data, **kwargs)


def error(message):
def error(message, **kwargs):
"""
获取一个失败的响应对象
:param message: 错误提示
:return: 接口响应对象
"""
return Result(code=500, message=message)
return Result(code=500, message=message, **kwargs)
30 changes: 26 additions & 4 deletions ui/src/api/application.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Result } from '@/request/Result'
import { get, post, postStream, del, put, request, download } from '@/request/index'
import { get, post, postStream, del, put, request, download, exportFile } from '@/request/index'
import type { pageRequest } from '@/api/type/common'
import type { ApplicationFormType } from '@/api/type/application'
import { type Ref } from 'vue'
Expand Down Expand Up @@ -300,7 +300,6 @@ const getApplicationTTIModel: (
return get(`${prefix}/${application_id}/model`, { model_type: 'TTI' }, loading)
}


/**
* 发布应用
* @param 参数
Expand Down Expand Up @@ -377,7 +376,6 @@ const uploadFile: (
return post(`${prefix}/${application_id}/chat/${chat_id}/upload_file`, data, undefined, loading)
}


/**
* 语音转文本
*/
Expand Down Expand Up @@ -503,6 +501,28 @@ const getUserList: (type: string, loading?: Ref<boolean>) => Promise<Result<any>
return get(`/user/list/${type}`, undefined, loading)
}

const exportApplication = (
application_id: string,
application_name: string,
loading?: Ref<boolean>
) => {
return exportFile(
application_name + '.mk',
`/application/${application_id}/export`,
undefined,
loading
)
}

/**
* 导入应用
*/
const importApplication: (data: any, loading?: Ref<boolean>) => Promise<Result<any>> = (
data,
loading
) => {
return post(`${prefix}/import`, data, undefined, loading)
}
export default {
getAllAppilcation,
getApplication,
Expand Down Expand Up @@ -544,5 +564,7 @@ export default {
playDemoText,
getUserList,
getApplicationList,
uploadFile
uploadFile,
exportApplication,
importApplication
}
Loading

0 comments on commit 53012c6

Please sign in to comment.