Skip to content

Commit

Permalink
feat: add speech_to_text node and text_to_speech node
Browse files Browse the repository at this point in the history
  • Loading branch information
wxg0103 committed Dec 13, 2024
1 parent 7bd791f commit 92aec4d
Show file tree
Hide file tree
Showing 36 changed files with 1,063 additions and 42 deletions.
4 changes: 3 additions & 1 deletion apps/application/flow/step_node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
from .image_generate_step_node import *

from .search_dataset_node import *
from .speech_to_text_step_node import BaseSpeechToTextNode
from .start_node import *
from .text_to_speech_step_node.impl.base_text_to_speech_node import BaseTextToSpeechNode

node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode,
BaseDocumentExtractNode,
BaseImageUnderstandNode, BaseImageGenerateNode, BaseFormNode]
BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode,BaseImageGenerateNode]


def get_node(node_type):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class ApplicationNodeSerializer(serializers.Serializer):
user_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.uuid("用户输入字段"))
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档"))
audio_list = serializers.ListField(required=False, error_messages=ErrMessage.list("音频"))
child_node = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("子节点"))
node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("表单数据"))

Expand Down Expand Up @@ -43,19 +44,30 @@ def _run(self):
app_document_list[1:])
for document in app_document_list:
if 'file_id' not in document:
raise ValueError("参数值错误: 上传的文档中缺少file_id")
raise ValueError("参数值错误: 上传的文档中缺少file_id,文档上传失败")
app_image_list = self.node_params_serializer.data.get('image_list', [])
if app_image_list and len(app_image_list) > 0:
app_image_list = self.workflow_manage.get_reference_field(
app_image_list[0],
app_image_list[1:])
for image in app_image_list:
if 'file_id' not in image:
raise ValueError("参数值错误: 上传的图片中缺少file_id")
raise ValueError("参数值错误: 上传的图片中缺少file_id,图片上传失败")

app_audio_list = self.node_params_serializer.data.get('audio_list', [])
if app_audio_list and len(app_audio_list) > 0:
app_audio_list = self.workflow_manage.get_reference_field(
app_audio_list[0],
app_audio_list[1:])
for audio in app_audio_list:
if 'file_id' not in audio:
raise ValueError("参数值错误: 上传的图片中缺少file_id,音频上传失败")
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data,
app_document_list=app_document_list, app_image_list=app_image_list,
app_audio_list=app_audio_list,
message=str(question), **kwargs)

def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
app_document_list=None, app_image_list=None, child_node=None, node_data=None, **kwargs) -> NodeResult:
app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None,
**kwargs) -> NodeResult:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def save_context(self, details, workflow_manage):
self.answer_text = details.get('answer')

def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
app_document_list=None, app_image_list=None, child_node=None, node_data=None,
app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None,
**kwargs) -> NodeResult:
from application.serializers.chat_message_serializers import ChatMessageSerializer
# 生成嵌入应用的chat_id
Expand All @@ -167,6 +167,8 @@ def execute(self, application_id, message, chat_id, chat_record_id, stream, re_c
app_document_list = []
if app_image_list is None:
app_image_list = []
if app_audio_list is None:
app_audio_list = []
runtime_node_id = None
record_id = None
child_node_value = None
Expand All @@ -186,6 +188,7 @@ def execute(self, application_id, message, chat_id, chat_record_id, stream, re_c
'client_type': client_type,
'document_list': app_document_list,
'image_list': app_image_list,
'audio_list': app_audio_list,
'runtime_node_id': runtime_node_id,
'chat_record_id': record_id,
'child_node': child_node_value,
Expand Down Expand Up @@ -234,5 +237,6 @@ def get_details(self, index: int, **kwargs):
'global_fields': global_fields,
'document_list': self.workflow_manage.document_list,
'image_list': self.workflow_manage.image_list,
'audio_list': self.workflow_manage.audio_list,
'application_node_dict': self.context.get('application_node_dict')
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# coding=utf-8

from .impl import *
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# coding=utf-8

from typing import Type

from rest_framework import serializers

from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage


class SpeechToTextNodeSerializer(serializers.Serializer):
    """Validate the parameters configured on a speech-to-text workflow node."""

    # Mandatory identifier of the speech-to-text model.
    stt_model_id = serializers.CharField(
        required=True,
        error_messages=ErrMessage.char("模型id"),
    )

    # Optional flag; the label handed to ErrMessage reads "whether to return content".
    is_result = serializers.BooleanField(
        required=False,
        error_messages=ErrMessage.boolean('是否返回内容'),
    )

    # Optional list locating the audio input; ISpeechToTextNode._run passes its
    # first element and the remaining elements to get_reference_field.
    audio_list = serializers.ListField(
        required=False,
        error_messages=ErrMessage.list("音频"),
    )


class ISpeechToTextNode(INode):
    """Interface for the workflow node that transcribes audio into text.

    Concrete subclasses implement ``execute``; ``_run`` resolves the audio
    reference configured on the node, validates it, and calls ``execute``.
    """
    type = 'speech-to-text-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this node's parameters."""
        return SpeechToTextNodeSerializer

    def _run(self):
        """Resolve and validate the referenced audio list, then call ``execute``.

        Raises:
            ValueError: if no audio reference is configured on the node, or if
                any resolved audio entry has no ``file_id`` key.
        """
        audio_reference = self.node_params_serializer.data.get('audio_list')
        # ``audio_list`` is optional in the serializer, so guard before indexing;
        # otherwise the user only sees a bare TypeError/IndexError.
        if not audio_reference:
            raise ValueError("参数值错误: 未设置音频文件")
        # First element and the remaining elements are handed to the workflow
        # manager separately (presumably node id + field path — confirm against
        # WorkflowManage.get_reference_field).
        res = self.workflow_manage.get_reference_field(audio_reference[0], audio_reference[1:])
        for audio in res:
            if 'file_id' not in audio:
                # The message previously said the uploaded *image* lacked file_id.
                raise ValueError("参数值错误: 上传的音频中缺少file_id,音频上传失败")

        return self.execute(audio=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, stt_model_id, chat_id,
                audio,
                **kwargs) -> NodeResult:
        """Transcribe ``audio`` with the model ``stt_model_id``; implemented by subclasses."""
        pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# coding=utf-8

from .base_speech_to_text_node import BaseSpeechToTextNode
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# coding=utf-8
import os
import tempfile
import time
import io
from typing import List, Dict

from django.db.models import QuerySet
from pydub import AudioSegment
from concurrent.futures import ThreadPoolExecutor
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.speech_to_text_step_node.i_speech_to_text_node import ISpeechToTextNode
from common.util.common import split_and_transcribe
from dataset.models import File
from setting.models_provider.tools import get_model_instance_by_model_user_id


class BaseSpeechToTextNode(ISpeechToTextNode):
    """Speech-to-text node: transcribes each referenced audio file and joins the texts."""

    def save_context(self, details, workflow_manage):
        """Restore the node's answer from previously saved execution details."""
        self.context['answer'] = details.get('answer')
        self.answer_text = details.get('answer')

    def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult:
        """Transcribe every entry of ``audio`` and return the combined text.

        Args:
            stt_model_id: id of the speech-to-text model to load for the flow's user.
            chat_id: accepted for interface compatibility; not used here.
            audio: iterable of dicts, each carrying a ``file_id`` key.

        Returns:
            NodeResult whose ``answer`` and ``result`` hold the transcriptions
            joined by blank lines, in the order of ``audio``.

        Raises:
            ValueError: if a ``file_id`` does not match any stored ``File``.
        """
        stt_model = get_model_instance_by_model_user_id(stt_model_id, self.flow_params_serializer.data.get('user_id'))
        self.context['audio_list'] = audio

        def process_audio_item(audio_item, model):
            # Copy the stored bytes to a temp file and transcribe from that path.
            file = QuerySet(File).filter(id=audio_item['file_id']).first()
            if file is None:
                # Without this guard the failure is an AttributeError on None.
                raise ValueError(f"参数值错误: 音频文件不存在, file_id: {audio_item['file_id']}")
            temp_file_path = None
            try:
                with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_file:
                    # Record the path before writing so a failed write is still cleaned up.
                    temp_file_path = temp_file.name
                    temp_file.write(file.get_byte().tobytes())
                return split_and_transcribe(temp_file_path, model)
            finally:
                # delete=False means the file must be removed explicitly.
                if temp_file_path is not None:
                    os.remove(temp_file_path)

        def process_audio_items(audio_list, model):
            # executor.map preserves input order, so the joined text follows ``audio``.
            with ThreadPoolExecutor(max_workers=5) as executor:
                results = list(executor.map(lambda item: process_audio_item(item, model), audio_list))
            return '\n\n'.join(results)

        result = process_audio_items(audio, stt_model)
        return NodeResult({'answer': result, 'result': result}, {})

    def get_details(self, index: int, **kwargs):
        """Return a summary dict of this node's execution for the run record."""
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            'status': self.status,
            'err_message': self.err_message,
            'audio_list': self.context.get('audio_list'),
        }
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def save_context(self, details, workflow_manage):
self.context['run_time'] = details.get('run_time')
self.context['document'] = details.get('document_list')
self.context['image'] = details.get('image_list')
self.context['audio'] = details.get('audio_list')
self.status = details.get('status')
self.err_message = details.get('err_message')
for key, value in workflow_variable.items():
Expand All @@ -57,7 +58,8 @@ def execute(self, question, **kwargs) -> NodeResult:
node_variable = {
'question': question,
'image': self.workflow_manage.image_list,
'document': self.workflow_manage.document_list
'document': self.workflow_manage.document_list,
'audio': self.workflow_manage.audio_list
}
return NodeResult(node_variable, workflow_variable)

Expand All @@ -80,5 +82,6 @@ def get_details(self, index: int, **kwargs):
'err_message': self.err_message,
'image_list': self.context.get('image'),
'document_list': self.context.get('document'),
'audio_list': self.context.get('audio'),
'global_fields': global_fields
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# coding=utf-8

from .impl import *
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# coding=utf-8

from typing import Type

from rest_framework import serializers

from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage


class TextToSpeechNodeSerializer(serializers.Serializer):
    """Validate the parameters configured on a text-to-speech workflow node."""
    # Mandatory identifier of the text-to-speech model.
    tts_model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))

    # Optional flag; the label handed to ErrMessage reads "whether to return content".
    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))

    # Optional list locating the text input; ITextToSpeechNode._run passes its
    # first element and the remaining elements to get_reference_field.
    content_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文本内容"))
    # Optional model settings. This is a DictField, so it uses the dict error
    # messages (it previously used the integer ones).
    model_params_setting = serializers.DictField(required=False,
                                                 error_messages=ErrMessage.dict("模型参数相关设置"))


class ITextToSpeechNode(INode):
    """Interface for the workflow node that synthesizes speech from text.

    Concrete subclasses implement ``execute``; ``_run`` resolves the text
    reference configured on the node and calls ``execute`` with it.
    """
    type = 'text-to-speech-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this node's parameters."""
        return TextToSpeechNodeSerializer

    def _run(self):
        """Resolve the referenced text content, then call ``execute``.

        Raises:
            ValueError: if no content reference is configured on the node.
        """
        content_reference = self.node_params_serializer.data.get('content_list')
        # ``content_list`` is optional in the serializer, so guard before indexing;
        # otherwise the user only sees a bare TypeError/IndexError.
        if not content_reference:
            raise ValueError("参数值错误: 未设置文本内容")
        # First element and the remaining elements are handed to the workflow
        # manager separately (presumably node id + field path — confirm against
        # WorkflowManage.get_reference_field).
        content = self.workflow_manage.get_reference_field(content_reference[0], content_reference[1:])
        return self.execute(content=content, **self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, tts_model_id, chat_id,
                content, model_params_setting=None,
                **kwargs) -> NodeResult:
        """Synthesize ``content`` with the model ``tts_model_id``; implemented by subclasses."""
        pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# coding=utf-8

from .base_text_to_speech_node import BaseTextToSpeechNode
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# coding=utf-8
import io
import mimetypes

from django.core.files.uploadedfile import InMemoryUploadedFile

from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
from application.flow.step_node.text_to_speech_step_node.i_text_to_speech_node import ITextToSpeechNode
from dataset.models import File
from dataset.serializers.file_serializers import FileSerializer
from setting.models_provider.tools import get_model_instance_by_model_user_id


def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"):
    """Wrap raw bytes in an ``InMemoryUploadedFile`` named ``file_name``.

    The content type is guessed from ``file_name``; when it cannot be guessed,
    the generic binary type ``application/octet-stream`` is used.
    """
    guessed_type = mimetypes.guess_type(file_name)[0]
    mime_type = "application/octet-stream" if guessed_type is None else guessed_type

    return InMemoryUploadedFile(
        file=io.BytesIO(file_bytes),  # in-memory stream over the payload
        field_name=None,
        name=file_name,
        content_type=mime_type,
        size=len(file_bytes),
        charset=None,
    )


class BaseTextToSpeechNode(ITextToSpeechNode):
    """Text-to-speech node: synthesizes audio, uploads it, and answers with an audio tag."""

    def save_context(self, details, workflow_manage):
        """Restore the node's answer from previously saved execution details."""
        self.context['answer'] = details.get('answer')
        self.answer_text = details.get('answer')

    def execute(self, tts_model_id, chat_id,
                content, model_params_setting=None,
                **kwargs) -> NodeResult:
        """Convert ``content`` to speech and return an HTML ``<audio>`` tag for it.

        Args:
            tts_model_id: id of the text-to-speech model to load for the flow's user.
            chat_id: recorded in the uploaded file's metadata.
            content: value passed to the model's ``text_to_speech``.
            model_params_setting: optional dict of extra keyword arguments for
                model loading; ``None`` is treated as no extra arguments.

        Returns:
            NodeResult whose ``answer`` and ``result`` hold the audio tag.
        """
        self.context['content'] = content
        # The parameter defaults to None and the serializer field is optional;
        # unpacking None with ** would raise TypeError, so fall back to {}.
        model = get_model_instance_by_model_user_id(tts_model_id, self.flow_params_serializer.data.get('user_id'),
                                                    **(model_params_setting or {}))
        audio_byte = model.text_to_speech(content)
        # Persist the generated audio through the file serializer
        # (original note: the audio file needs to be stored in the database).
        file_name = 'generated_audio.mp3'
        file = bytes_to_uploaded_file(audio_byte, file_name)
        application = self.workflow_manage.work_flow_post_handler.chat_info.application
        meta = {
            # An application without an id is flagged as a debug run.
            'debug': not application.id,
            'chat_id': chat_id,
            'application_id': str(application.id) if application.id else None,
        }
        file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
        # Build an <audio> tag whose src is the uploaded file's URL.
        audio_label = f'<audio src="{file_url}" controls style = "width: 300px; height: 43px" class ="border-r-4"/>'
        return NodeResult({'answer': audio_label, 'result': audio_label}, {})

    def get_details(self, index: int, **kwargs):
        """Return a summary dict of this node's execution for the run record."""
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'type': self.node.type,
            'status': self.status,
            'content': self.context.get('content'),
            'err_message': self.err_message,
            'answer': self.context.get('answer'),
        }
6 changes: 5 additions & 1 deletion apps/application/flow/workflow_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwa


end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
'image-understand-node', 'image-generate-node']
'image-understand-node', 'speech-to-text-node', 'text-to-speech-node', 'image-generate-node']


class Flow:
Expand Down Expand Up @@ -244,6 +244,7 @@ class WorkflowManage:
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None,
document_list=None,
audio_list=None,
start_node_id=None,
start_node_data=None, chat_record=None, child_node=None):
if form_data is None:
Expand All @@ -252,11 +253,14 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl
image_list = []
if document_list is None:
document_list = []
if audio_list is None:
audio_list = []
self.start_node_id = start_node_id
self.start_node = None
self.form_data = form_data
self.image_list = image_list
self.document_list = document_list
self.audio_list = audio_list
self.params = params
self.flow = flow
self.lock = threading.Lock()
Expand Down
4 changes: 3 additions & 1 deletion apps/application/serializers/chat_message_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ class ChatMessageSerializer(serializers.Serializer):
form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量"))
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档"))
audio_list = serializers.ListField(required=False, error_messages=ErrMessage.list("音频"))
child_node = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("子节点"))

def is_valid_application_workflow(self, *, raise_exception=False):
Expand Down Expand Up @@ -338,6 +339,7 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response):
form_data = self.data.get('form_data')
image_list = self.data.get('image_list')
document_list = self.data.get('document_list')
audio_list = self.data.get('audio_list')
user_id = chat_info.application.user_id
chat_record_id = self.data.get('chat_record_id')
chat_record = None
Expand All @@ -354,7 +356,7 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response):
'client_id': client_id,
'client_type': client_type,
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
base_to_response, form_data, image_list, document_list,
base_to_response, form_data, image_list, document_list, audio_list,
self.data.get('runtime_node_id'),
self.data.get('node_data'), chat_record, self.data.get('child_node'))
r = work_flow_manage.run()
Expand Down
Loading

0 comments on commit 92aec4d

Please sign in to comment.