Skip to content

Commit

Permalink
Add get_response route
Browse files Browse the repository at this point in the history
Implement safety checks on STT/TTS Request size
Add `user_profile` schema
  • Loading branch information
NeonDaniel committed Jan 22, 2024
1 parent 70cc490 commit 35c1a81
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 14 deletions.
4 changes: 3 additions & 1 deletion docker_overlay/etc/neon/diana.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@ hana:
server_host: "0.0.0.0"
server_port: 8080
fastapi_title: "Hana"
fastapi_summary: "HANA (HTTP API for Neon Applications) is the HTTP component of the Device Independent API for Neon Applications (DIANA)"
fastapi_summary: "HANA (HTTP API for Neon Applications) is the HTTP component of the Device Independent API for Neon Applications (DIANA)"
stt_max_length_encoded: 500000
tts_max_words: 128
4 changes: 4 additions & 0 deletions neon_hana/app/routers/assist.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,7 @@ async def get_stt(audio_in: STTRequest) -> STTResponse:
async def get_tts(request: TTSRequest) -> TTSResponse:
return mq_connector.get_tts(**dict(request))


@assist_route.post("/get_response")
async def get_response(request: SkillRequest) -> SkillResponse:
    # Expand the validated request model into keyword arguments
    # (`utterance`, `lang_code`, `user_profile`) and forward them to
    # `mq_connector.get_response`, returning its result unchanged.
    request_kwargs = dict(request)
    return mq_connector.get_response(**request_kwargs)
45 changes: 36 additions & 9 deletions neon_hana/mq_service_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from typing import Optional, Dict, Any, List
from uuid import uuid4
from fastapi import HTTPException

from neon_hana.schema.user_profile import UserProfile
from neon_mq_connector.utils.client_utils import send_mq_request


Expand All @@ -44,7 +46,8 @@ class MQServiceManager:
def __init__(self, config: dict):
self.mq_default_timeout = config.get('mq_default_timeout', 10)
self.mq_cliend_id = config.get('mq_client_id') or str(uuid4())
self.audio_tmp_path = mkdtemp("hana_audio")
self.stt_max_length = config.get('stt_max_length_encoded') or 500000
self.tts_max_words = config.get('tts_max_words') or 128

def _validate_api_proxy_response(self, response: dict):
if response['status_code'] == 200:
Expand Down Expand Up @@ -123,6 +126,10 @@ def get_coupons(self):
raise APIError(status_code=500, detail=repr(e))

def get_stt(self, encoded_audio: str, lang_code: str):
if 0 < self.stt_max_length < len(encoded_audio):
raise APIError(status_code=400,
detail=f"Audio exceeds maximum encoded length of "
f"{self.stt_max_length}")
request_data = {"msg_type": "neon.get_stt",
"data": {"audio_data": encoded_audio,
"utterances": [""], # TODO: Compat
Expand All @@ -136,18 +143,38 @@ def get_stt(self, encoded_audio: str, lang_code: str):
return response['data']

def get_tts(self, to_speak: str, lang_code: str, gender: str):
if 0 < self.tts_max_words < len(to_speak.split()):
raise APIError(status_code=400,
detail=f"Text exceeds maximum word count of "
f"{self.tts_max_words}")
request_data = {"msg_type": "neon.get_tts",
"data": {"text": to_speak,
"utterance": "", # TODO: Compat
"speaker": {"name": "Neon",
"gender": gender,
"lang": lang_code},
"data": {"text": to_speak,
"utterance": "", # TODO: Compat
"speaker": {"name": "Neon",
"gender": gender,
"lang": lang_code},
"context": {"source": "hana",
"ident": f"{self.mq_cliend_id}"
f"{time()}"}}
"lang": lang_code},
"context": {"source": "hana",
"ident": f"{self.mq_cliend_id}{time()}"}}
response = send_mq_request("/neon_chat_api", request_data,
"neon_chat_api_request",
timeout=self.mq_default_timeout)
audio = response['data'][lang_code]['audio'][gender]
return {"encoded_audio": audio}

def get_response(self, utterance: str, lang_code: str,
                 user_profile: UserProfile):
    """
    Submit a text utterance over MQ and return the response sentence.
    :param utterance: text to send as the single entry in `utterances`
    :param lang_code: language sent with the request and used to select
        the sentence from the response
    :param user_profile: profile included in the message context; its
        `user.username` is overwritten with this client's ID when empty
    :returns: dict with `answer` (response sentence) and `lang_code`
    """
    # Always write the username back so the dumped profile and the
    # `username` context entry agree, falling back to this client's ID.
    profile_user = user_profile.user
    username = profile_user.username or self.mq_cliend_id
    profile_user.username = username

    message_data = {"utterances": [utterance],
                    "lang": lang_code}
    message_context = {
        "username": username,
        "user_profiles": [user_profile.model_dump(mode="json")],
        "source": "hana",
        "ident": f"{self.mq_cliend_id}{time()}"}
    request_data = {"msg_type": "recognizer_loop:utterance",
                    "data": message_data,
                    "context": message_context}

    response = send_mq_request("/neon_chat_api", request_data,
                               "neon_chat_api_request",
                               timeout=self.mq_default_timeout)
    # Responses are keyed by language; pick the one that was requested.
    responses_by_lang = response['data']['responses']
    sentence = responses_by_lang[lang_code]['sentence']
    return {"answer": sentence, "lang_code": lang_code}
29 changes: 25 additions & 4 deletions neon_hana/schema/assist_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Optional, List
from typing import List
from pydantic import BaseModel

from neon_hana.schema.user_profile import UserProfile


class STTRequest(BaseModel):
encoded_audio: str
Expand Down Expand Up @@ -77,11 +79,30 @@ class TTSResponse(BaseModel):

# TODO: User profile model with below inputs?

class TextInput(BaseModel):
class SkillRequest(BaseModel):
    # Request body for a text utterance to be answered by Neon.
    # A `#` comment is used instead of a docstring so the generated JSON
    # schema description is left unchanged.
    utterance: str  # Text input to evaluate
    lang_code: str  # Language of `utterance` (e.g. "en-us")
    user_profile: UserProfile  # Profile of the requesting user

    model_config = {
        "json_schema_extra": {
            "examples": [{
                "utterance": "what time is it",
                "lang_code": "en-us",
                # Keys must match the `ProfileLocation` field names: the
                # longitude field is `lng`. A `lon` key is not a declared
                # field and would not populate the location.
                "user_profile": {"location": {"lat": 40.730610,
                                              "lng": -73.935242,
                                              "city": "New York",
                                              "state": "New York"}}
            }]}}

class AudioInput(BaseModel):
encoded_audio: str

# Response body pairing a text answer with its language code.
class SkillResponse(BaseModel):
    answer: str  # Response sentence
    lang_code: str  # Language code associated with `answer`

    # Example payload included in the generated JSON schema
    model_config = {
        "json_schema_extra": {
            "examples": [{
                "answer": "four forty three.",
                "lang_code": "en-us"
            }]}}
104 changes: 104 additions & 0 deletions neon_hana/schema/user_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System
# All trademark and other rights reserved by their respective owners
# Copyright 2008-2021 Neongecko.com Inc.
# BSD-3
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Optional, List
from pydantic import BaseModel


# Personal and account details for a user profile.
# Every field has a default, so an empty `ProfileUser()` is valid.
class ProfileUser(BaseModel):
    first_name: str = ""
    middle_name: str = ""
    last_name: str = ""
    preferred_name: str = ""
    full_name: str = ""
    # Default is a format placeholder string, not a real date
    dob: str = "YYYY/MM/DD"
    # Note: age is stored as a string, not a number
    age: str = ""
    email: str = ""
    username: str = ""
    password: str = ""
    picture: str = ""
    about: str = ""
    phone: str = ""
    phone_verified: bool = False
    email_verified: bool = False


# Brand preferences for a user profile; each field defaults to an empty dict.
# NOTE(review): the key/value layout of these dicts is not defined here.
class ProfileBrands(BaseModel):
    ignored_brands: dict = {}
    favorite_brands: dict = {}
    specially_requested: dict = {}


# Speech (STT/TTS) preferences for a user profile.
class ProfileSpeech(BaseModel):
    stt_language: str = "en-us"
    alt_languages: List[str] = ['en']
    tts_language: str = "en-us"
    tts_gender: str = "female"
    # Optional, but defaults to an empty string rather than None
    neon_voice: Optional[str] = ''
    # Optional, but defaults to an empty string rather than None
    secondary_tts_language: Optional[str] = ''
    secondary_tts_gender: str = "male"
    secondary_neon_voice: str = ''
    # presumably a scale factor for speech rate (1.0 = unchanged) — confirm
    speed_multiplier: float = 1.0


# Unit and format preferences for a user profile.
# The listed options are documentation only; the field types do not
# restrict values to them.
class ProfileUnits(BaseModel):
    time: int = 12
    # time options: 12, 24
    date: str = "MDY"
    # date options: MDY, YMD, YDM
    measure: str = "imperial"
    # measure options: imperial, metric


# Location details for a user profile; every field is optional and
# defaults to None.
class ProfileLocation(BaseModel):
    lat: Optional[float] = None
    # Longitude is named `lng` (not `lon`)
    lng: Optional[float] = None
    city: Optional[str] = None
    state: Optional[str] = None
    country: Optional[str] = None
    # presumably a timezone name — confirm expected format
    tz: Optional[str] = None
    # presumably the UTC offset in hours — confirm
    utc: Optional[float] = None


# Response-style preferences for a user profile.
class ProfileResponseMode(BaseModel):
    # NOTE(review): "quick" is the only value visible here; other valid
    # modes are not defined in this file
    speed_mode: str = "quick"
    hesitation: bool = False
    limit_dialog: bool = False


# Privacy preferences for a user profile; both flags default to False.
class ProfilePrivacy(BaseModel):
    save_audio: bool = False
    save_text: bool = False


# Complete user profile composed of the section models in this module.
# Every section defaults to that section's default instance, so a
# partial (or empty) profile is valid input.
class UserProfile(BaseModel):
    user: ProfileUser = ProfileUser()
    # The brands section is currently disabled
    # brands: ProfileBrands
    speech: ProfileSpeech = ProfileSpeech()
    units: ProfileUnits = ProfileUnits()
    location: ProfileLocation = ProfileLocation()
    response_mode: ProfileResponseMode = ProfileResponseMode()
    privacy: ProfilePrivacy = ProfilePrivacy()

0 comments on commit 35c1a81

Please sign in to comment.