diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 7c43ada4b68..1af1f48beea 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -10,6 +10,7 @@ import hashlib import json import os +import pickle import re import uuid from functools import reduce @@ -19,10 +20,10 @@ from django.core import cache, validators from django.core import signing from django.db import transaction, models -from django.db.models import QuerySet, Q +from django.db.models import QuerySet from django.http import HttpResponse from django.template import Template, Context -from rest_framework import serializers +from rest_framework import serializers, status from application.flow.workflow_manage import Flow from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion @@ -34,15 +35,17 @@ from common.db.search import get_dynamics_model, native_search, native_page_search from common.db.sql_execute import select_list from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed -from common.field.common import UploadedImageField +from common.field.common import UploadedImageField, UploadedFileField from common.models.db_model_manage import DBModelManage +from common.response import result from common.util.common import valid_license, password_encrypt from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from dataset.models import DataSet, Document, Image from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list from embedding.models import SearchMode -from function_lib.serializers.function_lib_serializer import FunctionLibSerializer +from function_lib.models.function import FunctionLib, PermissionType +from function_lib.serializers.function_lib_serializer import FunctionLibSerializer, FunctionLibModelSerializer from setting.models import AuthOperate from setting.models.model_management import Model from setting.models_provider import get_model_credential @@ -54,6 +57,13 @@ chat_cache = cache.caches['chat_cache'] +class MKInstance: + def __init__(self, application: dict, function_lib_list: List[dict], version: str): + self.application = application + self.function_lib_list = function_lib_list + self.version = version + + class ModelDatasetAssociation(serializers.Serializer): user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, @@ -662,6 +672,72 @@ def edit(self, with_valid=True): get_application_access_token(application_access_token.access_token, False) return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data)} + class Import(serializers.Serializer): + file = UploadedFileField(required=True, error_messages=ErrMessage.image("文件")) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + + @valid_license(model=Application, count=5, + message='社区版最多支持 5 个应用,如需拥有更多应用,请联系我们(https://fit2cloud.com/)。') + @transaction.atomic + def import_(self, with_valid=True): + if with_valid: + self.is_valid() + user_id = self.data.get('user_id') + mk_instance_bytes = self.data.get('file').read() + mk_instance = pickle.loads(mk_instance_bytes) + application = mk_instance.application + function_lib_list = mk_instance.function_lib_list + if len(function_lib_list) > 0: + function_lib_id_list = [function_lib.get('id') for function_lib in function_lib_list] + exits_function_lib_id_list = [str(function_lib.id) for function_lib in + QuerySet(FunctionLib).filter(id__in=function_lib_id_list)] + # 获取到需要插入的函数 + function_lib_list = [function_lib for function_lib in function_lib_list if + not exits_function_lib_id_list.__contains__(function_lib.get('id'))] + application_model = self.to_application(application, user_id) + function_lib_model_list = [self.to_function_lib(f, user_id) for f in function_lib_list] + application_model.save() + QuerySet(FunctionLib).bulk_create(function_lib_model_list) if len(function_lib_model_list) > 0 else None + return True + + @staticmethod + def to_application(application, user_id): + work_flow = application.get('work_flow') + for node in work_flow.get('nodes', []): + if node.get('type') == 'search-dataset-node': + node.get('properties', {}).get('node_data', {})['dataset_id_list'] = [] + return Application(id=uuid.uuid1(), user_id=user_id, name=application.get('name'), + desc=application.get('desc'), + prologue=application.get('prologue'), dialogue_number=application.get('dialogue_number'), + dataset_setting=application.get('dataset_setting'), + model_params_setting=application.get('model_params_setting'), + tts_model_params_setting=application.get('tts_model_params_setting'), + problem_optimization=application.get('problem_optimization'), + icon=application.get('icon'), + work_flow=work_flow, + type=application.get('type'), + problem_optimization_prompt=application.get('problem_optimization_prompt'), + tts_model_enable=application.get('tts_model_enable'), + stt_model_enable=application.get('stt_model_enable'), + tts_type=application.get('tts_type'), + clean_time=application.get('clean_time'), + file_upload_enable=application.get('file_upload_enable'), + file_upload_setting=application.get('file_upload_setting'), + ) + + @staticmethod + def to_function_lib(function_lib, user_id): + """ + + @param user_id: 用户id + @param function_lib: 函数库 + @return: + """ + return FunctionLib(id=function_lib.get('id'), user_id=user_id, name=function_lib.get('name'), + code=function_lib.get('code'), input_field_list=function_lib.get('input_field_list'), + is_active=function_lib.get('is_active'), + permission_type=PermissionType.PRIVATE) + class Operate(serializers.Serializer): application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id")) user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) @@ -708,6 +784,31 @@ def delete(self, with_valid=True): QuerySet(Application).filter(id=self.data.get('application_id')).delete() return True + def export(self, with_valid=True): + try: + if with_valid: + self.is_valid() + application_id = self.data.get('application_id') + application = QuerySet(Application).filter(id=application_id).first() + function_lib_id_list = [node.get('properties', {}).get('node_data', {}).get('function_lib_id') for node + in + application.work_flow.get('nodes', []) if + node.get('type') == 'function-lib-node'] + function_lib_list = [] + if len(function_lib_id_list) > 0: + function_lib_list = QuerySet(FunctionLib).filter(id__in=function_lib_id_list) + application_dict = ApplicationSerializerModel(application).data + + mk_instance = MKInstance(application_dict, + [FunctionLibModelSerializer(function_lib).data for function_lib in + function_lib_list], 'v1') + application_pickle = pickle.dumps(mk_instance) + response = HttpResponse(content_type='text/plain', content=application_pickle) + response['Content-Disposition'] = f'attachment; filename="{application.name}.mk"' + return response + except Exception as e: + return result.error(str(e), response_status=status.HTTP_500_INTERNAL_SERVER_ERROR) + @transaction.atomic def publish(self, instance, with_valid=True): if with_valid: diff --git a/apps/application/swagger_api/application_api.py b/apps/application/swagger_api/application_api.py index d05fbb04780..d8bdf79b88c 100644 --- a/apps/application/swagger_api/application_api.py +++ b/apps/application/swagger_api/application_api.py @@ -336,6 +336,27 @@ def get_request_params_api(): description='应用描述') ] + class Export(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='application_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='应用id'), + + ] + + class Import(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='file', + in_=openapi.IN_FORM, + type=openapi.TYPE_FILE, + required=True, + description='上传图片文件') + ] + class Operate(ApiMixin): @staticmethod def get_request_params_api(): diff --git a/apps/application/urls.py b/apps/application/urls.py index b4339ba8421..f4a6f932117 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -5,11 +5,13 @@ app_name = "application" urlpatterns = [ path('application', views.Application.as_view(), name="application"), + path('application/import', views.Application.Import.as_view()), path('application/profile', views.Application.Profile.as_view(), name='application/profile'), path('application/embed', views.Application.Embed.as_view()), path('application/authentication', views.Application.Authentication.as_view()), path('application//publish', views.Application.Publish.as_view()), path('application//edit_icon', views.Application.EditIcon.as_view()), + path('application//export', views.Application.Export.as_view()), path('application//statistics/customer_count', views.ApplicationStatistics.CustomerCount.as_view()), path('application//statistics/customer_count_trend', diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index f0873d62c74..dd16e2c65fe 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -27,7 +27,6 @@ from common.swagger_api.common_api import CommonApi from common.util.common import query_params_to_single_dict from dataset.serializers.dataset_serializers import DataSetSerializers -from setting.swagger_api.provide_api import ProvideApi chat_cache = cache.caches['chat_cache'] @@ -158,6 +157,34 @@ def put(self, request: Request, application_id: str): data={'application_id': application_id, 'user_id': request.user.id, 'image': request.FILES.get('file')}).edit(request.data)) + class Import(APIView): + authentication_classes = [TokenAuth] + parser_classes = [MultiPartParser] + + @action(methods="GET", detail=False) + @swagger_auto_schema(operation_summary="导入应用", operation_id="导入应用", + manual_parameters=ApplicationApi.Import.get_request_params_api(), + tags=["应用"] + ) + @has_permissions(RoleConstants.ADMIN, RoleConstants.USER) + def post(self, request: Request): + return result.success(ApplicationSerializer.Import( + data={'user_id': request.user.id, 'file': request.FILES.get('file')}).import_()) + + class Export(APIView): + authentication_classes = [TokenAuth] + + @action(methods="GET", detail=False) + @swagger_auto_schema(operation_summary="导出应用", operation_id="导出应用", + manual_parameters=ApplicationApi.Export.get_request_params_api(), + tags=["应用"] + ) + @has_permissions(lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))) + def get(self, request: Request, application_id: str): + return ApplicationSerializer.Operate( + data={'application_id': application_id, 'user_id': request.user.id}).export() + class Embed(APIView): @action(methods=["GET"], detail=False) @swagger_auto_schema(operation_summary="获取嵌入js", @@ -362,7 +389,8 @@ class AccessToken(APIView): compare=CompareConstants.AND)) def put(self, request: Request, application_id: str): return result.success( - ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).edit(request.data)) + ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).edit( + request.data)) @action(methods=['GET'], detail=False) @swagger_auto_schema(operation_summary="获取应用 AccessToken信息", @@ -382,9 +410,10 @@ def get(self, request: Request, application_id: str): class Authentication(APIView): @action(methods=['OPTIONS'], detail=False) def options(self, request, *args, **kwargs): - return HttpResponse(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true", - "Access-Control-Allow-Methods": "POST", - "Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}, ) + return HttpResponse( + headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true", + "Access-Control-Allow-Methods": "POST", + "Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}, ) @action(methods=['POST'], detail=False) @swagger_auto_schema(operation_summary="应用认证", @@ -404,6 +433,7 @@ def post(self, request: Request): ) @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="创建应用", operation_id="创建应用", request_body=ApplicationApi.Create.get_request_body_api(), @@ -444,7 +474,8 @@ def get(self, request: Request, application_id: str): "query_text": request.query_params.get("query_text"), "top_number": request.query_params.get("top_number"), 'similarity': request.query_params.get('similarity'), - 'search_mode': request.query_params.get('search_mode')}).hit_test( + 'search_mode': request.query_params.get( + 'search_mode')}).hit_test( )) class Publish(APIView): @@ -502,7 +533,8 @@ def delete(self, request: Request, application_id: str): compare=CompareConstants.AND)) def put(self, request: Request, application_id: str): return result.success( - ApplicationSerializer.Operate(data={'application_id': application_id, 'user_id': request.user.id}).edit( + ApplicationSerializer.Operate( + data={'application_id': application_id, 'user_id': request.user.id}).edit( request.data)) @action(methods=['GET'], detail=False) @@ -528,11 +560,14 @@ class ListApplicationDataSet(APIView): @swagger_auto_schema(operation_summary="获取当前应用可使用的知识库", operation_id="获取当前应用可使用的知识库", manual_parameters=ApplicationApi.Operate.get_request_params_api(), - responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()), + responses=result.get_api_array_response( + DataSetSerializers.Query.get_response_body_api()), tags=['应用']) @has_permissions(ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], - [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))], + [lambda r, keywords: Permission(group=Group.APPLICATION, + operate=Operate.USE, + dynamic_tag=keywords.get( + 'application_id'))], compare=CompareConstants.AND)) def get(self, request: Request, application_id: str): return result.success(ApplicationSerializer.Operate( diff --git a/apps/common/response/result.py b/apps/common/response/result.py index bb2ba0fafe5..3dc6bc35d35 100644 --- a/apps/common/response/result.py +++ b/apps/common/response/result.py @@ -157,10 +157,10 @@ def success(data, **kwargs): return Result(data=data, **kwargs) -def error(message): +def error(message, **kwargs): """ 获取一个失败的响应对象 :param message: 错误提示 :return: 接口响应对象 """ - return Result(code=500, message=message) + return Result(code=500, message=message, **kwargs) diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index 404303aebd9..5b6712ce8d0 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -1,5 +1,5 @@ import { Result } from '@/request/Result' -import { get, post, postStream, del, put, request, download } from '@/request/index' +import { get, post, postStream, del, put, request, download, exportFile } from '@/request/index' import type { pageRequest } from '@/api/type/common' import type { ApplicationFormType } from '@/api/type/application' import { type Ref } from 'vue' @@ -300,7 +300,6 @@ const getApplicationTTIModel: ( return get(`${prefix}/${application_id}/model`, { model_type: 'TTI' }, loading) } - /** * 发布应用 * @param 参数 @@ -377,7 +376,6 @@ const uploadFile: ( return post(`${prefix}/${application_id}/chat/${chat_id}/upload_file`, data, undefined, loading) } - /** * 语音转文本 */ @@ -503,6 +501,28 @@ const getUserList: (type: string, loading?: Ref) => Promise return get(`/user/list/${type}`, undefined, loading) } +const exportApplication = ( + application_id: string, + application_name: string, + loading?: Ref +) => { + return exportFile( + application_name + '.mk', + `/application/${application_id}/export`, + undefined, + loading + ) +} + +/** + * 导入应用 + */ +const importApplication: (data: any, loading?: Ref) => Promise> = ( + data, + loading +) => { + return post(`${prefix}/import`, data, undefined, loading) +} export default { getAllAppilcation, getApplication, @@ -544,5 +564,7 @@ export default { playDemoText, getUserList, getApplicationList, - uploadFile + uploadFile, + exportApplication, + importApplication } diff --git a/ui/src/request/index.ts b/ui/src/request/index.ts index 050b41e5dbf..7d65d65f2b8 100644 --- a/ui/src/request/index.ts +++ b/ui/src/request/index.ts @@ -227,7 +227,6 @@ export const exportExcel: ( ) => { return promise(request({ url: url, method: 'get', params, responseType: 'blob' }), loading).then( (res: any) => { - console.log(res) if (res) { const blob = new Blob([res], { type: 'application/vnd.ms-excel' @@ -244,6 +243,35 @@ export const exportExcel: ( ) } +export const exportFile: ( + fileName: string, + url: string, + params: any, + loading?: NProgress | Ref +) => Promise = ( + fileName: string, + url: string, + params: any, + loading?: NProgress | Ref +) => { + return promise(request({ url: url, method: 'get', params, responseType: 'blob' }), loading).then( + (res: any) => { + if (res) { + const blob = new Blob([res], { + type: 'application/octet-stream' + }) + const link = document.createElement('a') + link.href = window.URL.createObjectURL(blob) + link.download = fileName + link.click() + //释放内存 + window.URL.revokeObjectURL(link.href) + } + return true + } + ) +} + export const exportExcelPost: ( fileName: string, url: string, diff --git a/ui/src/views/application/index.vue b/ui/src/views/application/index.vue index 97eba9c05f3..468c2fbd2b2 100644 --- a/ui/src/views/application/index.vue +++ b/ui/src/views/application/index.vue @@ -3,6 +3,18 @@

{{ $t('views.application.applicationList.title') }}

+ + 导入应用 + 复制 - + + 导出 + {{ $t('views.application.applicationList.card.delete.tooltip') }} @@ -152,7 +166,7 @@ import { ref, onMounted, reactive } from 'vue' import applicationApi from '@/api/application' import CreateApplicationDialog from './component/CreateApplicationDialog.vue' import CopyApplicationDialog from './component/CopyApplicationDialog.vue' -import { MsgSuccess, MsgConfirm, MsgAlert } from '@/utils/message' +import { MsgSuccess, MsgConfirm, MsgAlert, MsgError } from '@/utils/message' import { isAppIcon } from '@/utils/application' import { useRouter } from 'vue-router' import { isWorkFlow } from '@/utils/application' @@ -203,7 +217,20 @@ function settingApplication(row: any) { router.push({ path: `/application/${row.id}/${row.type}/setting` }) } } - +const exportApplication = (application: any) => { + applicationApi.exportApplication(application.id, application.name, loading).catch((e) => { + e.response.data.text().then((res: string) => { + MsgError(`导出失败:${JSON.parse(res).message}`) + }) + }) +} +const importApplication = (file: any) => { + const formData = new FormData() + formData.append('file', file.raw, file.name) + applicationApi.importApplication(formData, loading).then((ok) => { + searchHandle() + }) +} function openCreateDialog() { if (user.isEnterprise()) { CreateApplicationDialogRef.value.open()