-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add speech_to_text node and text_to_speech node
- Loading branch information
Showing
36 changed files
with
1,063 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3 changes: 3 additions & 0 deletions
3
apps/application/flow/step_node/speech_to_text_step_node/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# coding=utf-8 | ||
|
||
from .impl import * |
37 changes: 37 additions & 0 deletions
37
apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# coding=utf-8 | ||
|
||
from typing import Type | ||
|
||
from rest_framework import serializers | ||
|
||
from application.flow.i_step_node import INode, NodeResult | ||
from common.util.field_message import ErrMessage | ||
|
||
|
||
class SpeechToTextNodeSerializer(serializers.Serializer):
    """Validate the parameters configured on a speech-to-text workflow node."""

    # Id of the speech-to-text model to use; the only mandatory field.
    stt_model_id = serializers.CharField(
        required=True,
        error_messages=ErrMessage.char("模型id"),
    )

    # Optional boolean flag (its label reads "whether to return content").
    is_result = serializers.BooleanField(
        required=False,
        error_messages=ErrMessage.boolean('是否返回内容'),
    )

    # Optional reference to the audio input. ISpeechToTextNode._run passes the
    # first element and the remaining elements to
    # workflow_manage.get_reference_field.
    audio_list = serializers.ListField(
        required=False,
        error_messages=ErrMessage.list("音频"),
    )
|
||
|
||
class ISpeechToTextNode(INode):
    """Interface of the speech-to-text workflow node.

    Subclasses implement ``execute``; ``_run`` resolves the referenced audio
    list, validates it and forwards everything to ``execute``.
    """
    type = 'speech-to-text-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this node's parameters."""
        return SpeechToTextNodeSerializer

    def _run(self):
        """Resolve the audio reference, validate each entry and call ``execute``.

        Returns:
            Whatever ``execute`` returns.

        Raises:
            ValueError: if no audio reference is configured, or if a resolved
                audio entry does not contain ``file_id``.
        """
        node_params = self.node_params_serializer.data
        audio_reference = node_params.get('audio_list')
        # ``audio_list`` is optional in the serializer, so it may be missing or
        # empty; fail with a clear message instead of an opaque
        # TypeError/IndexError when indexing it.
        if not audio_reference:
            raise ValueError("参数值错误: 未配置音频引用(audio_list)")

        res = self.workflow_manage.get_reference_field(audio_reference[0], audio_reference[1:])
        for audio in res:
            if 'file_id' not in audio:
                # The message refers to audio (it previously said "image").
                raise ValueError("参数值错误: 上传的音频中缺少file_id,音频上传失败")

        return self.execute(audio=res, **node_params, **self.flow_params_serializer.data)

    def execute(self, stt_model_id, chat_id,
                audio,
                **kwargs) -> NodeResult:
        """Transcribe ``audio`` with the model ``stt_model_id``.

        This base implementation does nothing; concrete nodes override it.
        """
        pass
3 changes: 3 additions & 0 deletions
3
apps/application/flow/step_node/speech_to_text_step_node/impl/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# coding=utf-8 | ||
|
||
from .base_speech_to_text_node import BaseSpeechToTextNode |
58 changes: 58 additions & 0 deletions
58
apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# coding=utf-8 | ||
import os | ||
import tempfile | ||
import time | ||
import io | ||
from typing import List, Dict | ||
|
||
from django.db.models import QuerySet | ||
from pydub import AudioSegment | ||
from concurrent.futures import ThreadPoolExecutor | ||
from application.flow.i_step_node import NodeResult, INode | ||
from application.flow.step_node.speech_to_text_step_node.i_speech_to_text_node import ISpeechToTextNode | ||
from common.util.common import split_and_transcribe | ||
from dataset.models import File | ||
from setting.models_provider.tools import get_model_instance_by_model_user_id | ||
|
||
|
||
class BaseSpeechToTextNode(ISpeechToTextNode):
    """Speech-to-text node that transcribes stored audio files with an STT model."""

    def save_context(self, details, workflow_manage):
        """Copy the ``answer`` stored in ``details`` back onto this node."""
        answer = details.get('answer')
        self.context['answer'] = answer
        self.answer_text = answer

    def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult:
        """Transcribe every entry of ``audio`` and join the transcriptions.

        Args:
            stt_model_id: id of the speech-to-text model to load.
            chat_id: id of the current chat (not used by this implementation).
            audio: iterable of mappings, each providing a ``file_id`` key.
            **kwargs: remaining node/flow parameters (ignored here).

        Returns:
            A ``NodeResult`` whose ``answer`` and ``result`` are the
            transcriptions joined by a blank line, in input order.

        Raises:
            ValueError: if a ``file_id`` does not match any stored ``File``.
        """
        stt_model = get_model_instance_by_model_user_id(stt_model_id, self.flow_params_serializer.data.get('user_id'))
        self.context['audio_list'] = audio

        def process_audio_item(audio_item):
            file = QuerySet(File).filter(id=audio_item['file_id']).first()
            if file is None:
                # ``first()`` returns None when nothing matches; report that
                # instead of failing with an AttributeError below.
                raise ValueError(f"参数值错误: 未找到file_id对应的音频文件: {audio_item['file_id']}")
            temp_file_path = None
            try:
                # delete=False: the file has to outlive the ``with`` block so
                # that split_and_transcribe can read it by path.
                with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_file:
                    temp_file_path = temp_file.name
                    temp_file.write(file.get_byte().tobytes())
                return split_and_transcribe(temp_file_path, stt_model)
            finally:
                # Remove the temp file even when writing or transcription fails.
                if temp_file_path is not None:
                    os.remove(temp_file_path)

        # executor.map keeps the input order of the results.
        with ThreadPoolExecutor(max_workers=5) as executor:
            transcriptions = list(executor.map(process_audio_item, audio))

        result = '\n\n'.join(transcriptions)
        return NodeResult({'answer': result, 'result': result}, {})

    def get_details(self, index: int, **kwargs):
        """Return a dict describing this node's run, built from its context and state."""
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            'status': self.status,
            'err_message': self.err_message,
            'audio_list': self.context.get('audio_list'),
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3 changes: 3 additions & 0 deletions
3
apps/application/flow/step_node/text_to_speech_step_node/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# coding=utf-8 | ||
|
||
from .impl import * |
35 changes: 35 additions & 0 deletions
35
apps/application/flow/step_node/text_to_speech_step_node/i_text_to_speech_node.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# coding=utf-8 | ||
|
||
from typing import Type | ||
|
||
from rest_framework import serializers | ||
|
||
from application.flow.i_step_node import INode, NodeResult | ||
from common.util.field_message import ErrMessage | ||
|
||
|
||
class TextToSpeechNodeSerializer(serializers.Serializer):
    """Validate the parameters configured on a text-to-speech workflow node."""

    # Id of the text-to-speech model to use; the only mandatory field.
    tts_model_id = serializers.CharField(
        required=True,
        error_messages=ErrMessage.char("模型id"),
    )

    # Optional boolean flag (its label reads "whether to return content").
    is_result = serializers.BooleanField(
        required=False,
        error_messages=ErrMessage.boolean('是否返回内容'),
    )

    # Optional reference to the text input. ITextToSpeechNode._run passes the
    # first element and the remaining elements to
    # workflow_manage.get_reference_field.
    content_list = serializers.ListField(
        required=False,
        error_messages=ErrMessage.list("文本内容"),
    )

    # Optional model settings.
    # NOTE(review): this DictField takes its messages from ErrMessage.integer;
    # confirm whether a dict-specific message set was intended.
    model_params_setting = serializers.DictField(
        required=False,
        error_messages=ErrMessage.integer("模型参数相关设置"),
    )
|
||
|
||
class ITextToSpeechNode(INode):
    """Interface of the text-to-speech workflow node.

    Subclasses implement ``execute``; ``_run`` resolves the referenced text
    and forwards everything to ``execute``.
    """
    type = 'text-to-speech-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this node's parameters."""
        return TextToSpeechNodeSerializer

    def _run(self):
        """Resolve the content reference and call ``execute`` with it.

        Returns:
            Whatever ``execute`` returns.

        Raises:
            ValueError: if no content reference is configured.
        """
        node_params = self.node_params_serializer.data
        content_reference = node_params.get('content_list')
        # ``content_list`` is optional in the serializer, so it may be missing
        # or empty; fail with a clear message instead of an opaque
        # TypeError/IndexError when indexing it.
        if not content_reference:
            raise ValueError("参数值错误: 未配置文本内容引用(content_list)")

        content = self.workflow_manage.get_reference_field(content_reference[0], content_reference[1:])
        return self.execute(content=content, **node_params, **self.flow_params_serializer.data)

    def execute(self, tts_model_id, chat_id,
                content, model_params_setting=None,
                **kwargs) -> NodeResult:
        """Convert ``content`` to speech with the model ``tts_model_id``.

        This base implementation does nothing; concrete nodes override it.
        """
        pass
3 changes: 3 additions & 0 deletions
3
apps/application/flow/step_node/text_to_speech_step_node/impl/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# coding=utf-8 | ||
|
||
from .base_text_to_speech_node import BaseTextToSpeechNode |
73 changes: 73 additions & 0 deletions
73
apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# coding=utf-8 | ||
import io | ||
import mimetypes | ||
|
||
from django.core.files.uploadedfile import InMemoryUploadedFile | ||
|
||
from application.flow.i_step_node import NodeResult, INode | ||
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode | ||
from application.flow.step_node.text_to_speech_step_node.i_text_to_speech_node import ITextToSpeechNode | ||
from dataset.models import File | ||
from dataset.serializers.file_serializers import FileSerializer | ||
from setting.models_provider.tools import get_model_instance_by_model_user_id | ||
|
||
|
||
def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"):
    """Wrap raw bytes in an ``InMemoryUploadedFile`` named ``file_name``.

    The content type is guessed from ``file_name``; when it cannot be guessed
    the generic binary type is used.
    """
    guessed_type = mimetypes.guess_type(file_name)[0]
    # Fall back to the default binary type when the name is not recognised.
    mime_type = "application/octet-stream" if guessed_type is None else guessed_type

    return InMemoryUploadedFile(
        # In-memory byte stream holding the payload.
        file=io.BytesIO(file_bytes),
        field_name=None,
        name=file_name,
        content_type=mime_type,
        # The size is the length of the payload in bytes.
        size=len(file_bytes),
        charset=None,
    )
|
||
|
||
class BaseTextToSpeechNode(ITextToSpeechNode):
    """Text-to-speech node that synthesises audio, stores it and answers with an audio tag."""

    def save_context(self, details, workflow_manage):
        """Copy the ``answer`` stored in ``details`` back onto this node."""
        answer = details.get('answer')
        self.context['answer'] = answer
        self.answer_text = answer

    def execute(self, tts_model_id, chat_id,
                content, model_params_setting=None,
                **kwargs) -> NodeResult:
        """Synthesise ``content`` to audio, upload it and return an ``<audio>`` tag.

        Args:
            tts_model_id: id of the text-to-speech model to load.
            chat_id: id of the current chat; stored in the uploaded file's meta.
            content: text to convert to speech.
            model_params_setting: optional mapping of extra keyword arguments
                for the model loader; ``None`` means no extra arguments.
            **kwargs: remaining node/flow parameters (ignored here).

        Returns:
            A ``NodeResult`` whose ``answer`` and ``result`` are the audio tag.
        """
        self.context['content'] = content
        # ``model_params_setting`` defaults to None, and ``**None`` raises a
        # TypeError, so unpack an empty mapping in that case.
        model = get_model_instance_by_model_user_id(tts_model_id, self.flow_params_serializer.data.get('user_id'),
                                                    **(model_params_setting or {}))
        audio_byte = model.text_to_speech(content)
        # The generated audio has to be stored through the file serializer.
        file_name = 'generated_audio.mp3'
        file = bytes_to_uploaded_file(audio_byte, file_name)
        application = self.workflow_manage.work_flow_post_handler.chat_info.application
        meta = {
            # Flagged as debug when the application has no id.
            'debug': not application.id,
            'chat_id': chat_id,
            'application_id': str(application.id) if application.id else None,
        }
        file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
        # Build an audio tag whose src attribute is the value returned by upload().
        audio_label = f'<audio src="{file_url}" controls style = "width: 300px; height: 43px" class ="border-r-4"/>'
        return NodeResult({'answer': audio_label, 'result': audio_label}, {})

    def get_details(self, index: int, **kwargs):
        """Return a dict describing this node's run, built from its context and state."""
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'type': self.node.type,
            'status': self.status,
            'content': self.context.get('content'),
            'err_message': self.err_message,
            'answer': self.context.get('answer'),
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.