-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
KNOOP
committed
Jan 18, 2025
0 parents
commit e95bdef
Showing
20 changed files
with
4,562 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from __future__ import annotations | ||
from homeassistant.config_entries import ConfigEntry | ||
from homeassistant.const import CONF_API_KEY, Platform | ||
from homeassistant.core import HomeAssistant, callback | ||
from homeassistant.exceptions import ConfigEntryNotReady | ||
from homeassistant.helpers.dispatcher import async_dispatcher_send | ||
from .const import DOMAIN, LOGGER | ||
from .intents import get_intent_handler, async_setup_intents | ||
from .services import async_setup_services | ||
from .web_search import async_setup_web_search | ||
from .image_gen import async_setup_image_gen | ||
from .entity_analysis import async_setup_entity_analysis | ||
from .process_with_ha import async_setup as async_setup_process_with_ha | ||
|
||
PLATFORMS: list[Platform] = [Platform.CONVERSATION] | ||
|
||
class ZhipuAIConfigEntry: | ||
def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry): | ||
self.hass = hass | ||
self.config_entry = config_entry | ||
self.api_key = config_entry.data[CONF_API_KEY] | ||
self.options = config_entry.options | ||
self._unsub_options_update_listener = None | ||
self._cleanup_callbacks = [] | ||
self.intent_handler = get_intent_handler(hass) | ||
|
||
@property | ||
def entry_id(self): | ||
return self.config_entry.entry_id | ||
|
||
@property | ||
def title(self): | ||
return self.config_entry.title | ||
|
||
async def async_setup(self) -> None: | ||
self._unsub_options_update_listener = self.config_entry.add_update_listener( | ||
self.async_options_updated | ||
) | ||
|
||
async def async_unload(self) -> None: | ||
if self._unsub_options_update_listener is not None: | ||
self._unsub_options_update_listener() | ||
self._unsub_options_update_listener = None | ||
for cleanup_callback in self._cleanup_callbacks: | ||
cleanup_callback() | ||
self._cleanup_callbacks.clear() | ||
|
||
def async_on_unload(self, func): | ||
self._cleanup_callbacks.append(func) | ||
|
||
@callback | ||
async def async_options_updated(self, hass: HomeAssistant, entry: ConfigEntry) -> None: | ||
self.options = entry.options | ||
async_dispatcher_send(hass, f"{DOMAIN}_options_updated", entry) | ||
|
||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: | ||
try: | ||
zhipuai_entry = ZhipuAIConfigEntry(hass, entry) | ||
await zhipuai_entry.async_setup() | ||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = zhipuai_entry | ||
|
||
unload_services = await async_setup_services(hass) | ||
zhipuai_entry.async_on_unload(unload_services) | ||
|
||
await async_setup_intents(hass) | ||
|
||
await async_setup_web_search(hass) | ||
await async_setup_image_gen(hass) | ||
await async_setup_entity_analysis(hass) | ||
await async_setup_process_with_ha(hass) | ||
|
||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) | ||
return True | ||
|
||
except Exception as ex: | ||
raise ConfigEntryNotReady from ex | ||
|
||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: | ||
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) | ||
try: | ||
zhipuai_entry = hass.data[DOMAIN].get(entry.entry_id) | ||
if zhipuai_entry is not None and hasattr(zhipuai_entry, 'async_unload'): | ||
await zhipuai_entry.async_unload() | ||
except Exception: | ||
pass | ||
finally: | ||
hass.data[DOMAIN].pop(entry.entry_id, None) | ||
return unload_ok |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import aiohttp | ||
from aiohttp import TCPConnector | ||
from homeassistant.exceptions import HomeAssistantError | ||
from .const import LOGGER, ZHIPUAI_URL, CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT | ||
|
||
_SESSION = None | ||
|
||
async def get_session(): | ||
global _SESSION | ||
if _SESSION is None: | ||
connector = TCPConnector(ssl=False) | ||
_SESSION = aiohttp.ClientSession(connector=connector) | ||
return _SESSION | ||
|
||
async def send_ai_request(api_key: str, payload: dict, options: dict = None) -> dict: | ||
try: | ||
session = await get_session() | ||
headers = { | ||
"Authorization": f"Bearer {api_key}", | ||
"Content-Type": "application/json" | ||
} | ||
timeout = aiohttp.ClientTimeout(total=options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)) | ||
async with session.post(ZHIPUAI_URL, json=payload, headers=headers, timeout=timeout) as response: | ||
if response.status == 401: | ||
err_str = "API密钥无效或已过期" | ||
elif response.status == 429: | ||
err_str = "请求过于频繁,请稍后再试" | ||
elif response.status in [500, 502, 503]: | ||
err_str = "AI服务器暂时不可用,请稍后再试" | ||
elif response.status == 400: | ||
result = await response.json() | ||
err_str = f"请求参数错误: {result.get('error', {}).get('message', '未知错误')}" | ||
elif response.status != 200: | ||
err_str = f"AI服务返回错误 {response.status}" | ||
|
||
if response.status != 200: | ||
LOGGER.error("AI请求错误: %s", err_str) | ||
raise HomeAssistantError(err_str) | ||
|
||
result = await response.json() | ||
if "error" in result: | ||
err_msg = result["error"].get("message", "未知错误") | ||
LOGGER.error("AI返回错误: %s", err_msg) | ||
if "token" in err_msg.lower(): | ||
raise HomeAssistantError("生成的文本太长,请尝试缩短请求或减小max_tokens值") | ||
elif "rate" in err_msg.lower(): | ||
raise HomeAssistantError("请求过于频繁,请稍后再试") | ||
else: | ||
raise HomeAssistantError(f"AI服务返回错误: {err_msg}") | ||
return result | ||
|
||
except Exception as err: | ||
err_str = str(err).lower() if str(err) else "未知错误" | ||
LOGGER.error("AI通信错误: %s", err_str) | ||
|
||
if not err_str or err_str.isspace(): | ||
error_msg = "与AI服务通信失败,请检查网络连接和API密钥配置。" | ||
else: | ||
error_msg = "很抱歉,我现在无法正确处理您的请求。" + ( | ||
"网络连接失败,请检查网络设置。" if any(x in err_str for x in ["通信", "communication", "connect", "socket"]) else | ||
"请求超时,尝试减小max_tokens值或缩短请求。" if any(x in err_str for x in ["timeout", "connection", "network"]) else | ||
"API密钥无效或已过期,请更新配置。" if any(x in err_str for x in ["api key", "token", "unauthorized", "authentication"]) else | ||
"请求参数错误,请检查配置。" if "参数" in err_str or "parameter" in err_str else | ||
f"发生错误: {err_str}" | ||
) | ||
|
||
LOGGER.error("与 AI 通信时出错: %s", err_str) | ||
raise HomeAssistantError(error_msg) |
Oops, something went wrong.