From a6091b065c2d7d9d507ea5b0825260cf8d63a05a Mon Sep 17 00:00:00 2001 From: KNOOP Date: Mon, 13 Jan 2025 10:36:16 +0800 Subject: [PATCH] Release v2025.01.14 --- README.md | 35 + custom_components/zhipuai/__init__.py | 86 +++ custom_components/zhipuai/ai_request.py | 68 ++ custom_components/zhipuai/config_flow.py | 414 ++++++++++++ custom_components/zhipuai/const.py | 63 ++ custom_components/zhipuai/conversation.py | 629 ++++++++++++++++++ custom_components/zhipuai/entity_analysis.py | 118 ++++ custom_components/zhipuai/image_gen.py | 109 +++ custom_components/zhipuai/intents.py | 543 +++++++++++++++ custom_components/zhipuai/intents.yaml | 352 ++++++++++ custom_components/zhipuai/manifest.json | 13 + custom_components/zhipuai/services.py | 440 ++++++++++++ custom_components/zhipuai/services.yaml | 221 ++++++ .../zhipuai/translations/en.json | 102 +++ .../zhipuai/translations/zh-Hans.json | 102 +++ custom_components/zhipuai/web_search.py | 179 +++++ hacs.json | 5 + 17 files changed, 3479 insertions(+) create mode 100644 README.md create mode 100644 custom_components/zhipuai/__init__.py create mode 100644 custom_components/zhipuai/ai_request.py create mode 100644 custom_components/zhipuai/config_flow.py create mode 100644 custom_components/zhipuai/const.py create mode 100644 custom_components/zhipuai/conversation.py create mode 100644 custom_components/zhipuai/entity_analysis.py create mode 100644 custom_components/zhipuai/image_gen.py create mode 100644 custom_components/zhipuai/intents.py create mode 100644 custom_components/zhipuai/intents.yaml create mode 100644 custom_components/zhipuai/manifest.json create mode 100644 custom_components/zhipuai/services.py create mode 100644 custom_components/zhipuai/services.yaml create mode 100644 custom_components/zhipuai/translations/en.json create mode 100644 custom_components/zhipuai/translations/zh-Hans.json create mode 100644 custom_components/zhipuai/web_search.py create mode 100644 hacs.json diff --git a/README.md b/README.md new file mode 100644 index 0000000..78ea104 --- /dev/null +++ b/README.md @@ -0,0 +1,35 @@ +# Zhipu Ai Words AI Home Assistant 🏡 + +![GitHub Version]( https://img.shields.io/github/v/release/knoop7/zhipuai ) ![GitHub Issues]( https://img.shields.io/github/issues/knoop7/zhipuai ) ![GitHub Forks]( https://img.shields.io/github/forks/knoop7/zhipuai?style=social ) ![GitHub Stars]( https://img.shields.io/github/stars/knoop7/zhipuai?style=social ) + +English | 简体中文 + +image + + +--- +## Notice: This project is strictly prohibited from commercial use without permission. You may use it as a means of profit, but it cannot be concealed. +### 📦 Installation steps +#### 1. HACS adds custom repository +In the HACS of Home Assistant, click on the three dots in the upper right corner, select "Custom Repository", and add the following URL: + +``` +https://github.com/knoop7/zhipuai +``` + + +#### 2. Add Zhipu Qingyan Integration +Go to the "Integration" page of Home Assistant, search for and add "Zhipu Qingyan". + +#### 3. Configure Key 🔑 +In the configuration page, you can log in with your phone number to obtain the Key. After obtaining it, simply fill in the Key for use without the need for additional verification. +**Attention * *: It is recommended that you create a new Key and avoid using the system's default Key. + +#### 4. Free model usage 💡 +Zhipu Qingyan has chosen the free model by default, which is completely free and there is no need to worry about charging. If you are interested, you can also choose other paid models to experience richer features. + +#### 5. Version compatibility 📅 +Please ensure that the version of Home Assistant is not lower than 11.0, as Zhipu Qingyan is mainly developed for the latest version. If encountering unrecognized entity issues, it is recommended to restart the system or update to the latest version. + +--- + diff --git a/custom_components/zhipuai/__init__.py b/custom_components/zhipuai/__init__.py new file mode 100644 index 0000000..a3494cc --- /dev/null +++ b/custom_components/zhipuai/__init__.py @@ -0,0 +1,86 @@ +from __future__ import annotations +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import CONF_API_KEY, Platform +from homeassistant.core import HomeAssistant, callback +from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.helpers.dispatcher import async_dispatcher_send +from .const import DOMAIN, LOGGER +from .intents import get_intent_handler, async_setup_intents +from .services import async_setup_services +from .web_search import async_setup_web_search +from .image_gen import async_setup_image_gen +from .entity_analysis import async_setup_entity_analysis + +PLATFORMS: list[Platform] = [Platform.CONVERSATION] + +class ZhipuAIConfigEntry: + def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry): + self.hass = hass + self.config_entry = config_entry + self.api_key = config_entry.data[CONF_API_KEY] + self.options = config_entry.options + self._unsub_options_update_listener = None + self._cleanup_callbacks = [] + self.intent_handler = get_intent_handler(hass) + + @property + def entry_id(self): + return self.config_entry.entry_id + + @property + def title(self): + return self.config_entry.title + + async def async_setup(self) -> None: + self._unsub_options_update_listener = self.config_entry.add_update_listener( + self.async_options_updated + ) + + async def async_unload(self) -> None: + if self._unsub_options_update_listener is not None: + self._unsub_options_update_listener() + self._unsub_options_update_listener = None + for cleanup_callback in self._cleanup_callbacks: + cleanup_callback() + self._cleanup_callbacks.clear() + + def async_on_unload(self, func): + self._cleanup_callbacks.append(func) + + @callback + async def async_options_updated(self, hass: HomeAssistant, entry: ConfigEntry) -> None: + self.options = entry.options + async_dispatcher_send(hass, f"{DOMAIN}_options_updated", entry) + +async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + try: + zhipuai_entry = ZhipuAIConfigEntry(hass, entry) + await zhipuai_entry.async_setup() + hass.data.setdefault(DOMAIN, {})[entry.entry_id] = zhipuai_entry + + unload_services = await async_setup_services(hass) + zhipuai_entry.async_on_unload(unload_services) + + await async_setup_intents(hass) + + await async_setup_web_search(hass) + await async_setup_image_gen(hass) + await async_setup_entity_analysis(hass) + + await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) + return True + + except Exception as ex: + raise ConfigEntryNotReady from ex + +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) + try: + zhipuai_entry = hass.data[DOMAIN].get(entry.entry_id) + if zhipuai_entry is not None and hasattr(zhipuai_entry, 'async_unload'): + await zhipuai_entry.async_unload() + except Exception: + pass + finally: + hass.data[DOMAIN].pop(entry.entry_id, None) + return unload_ok \ No newline at end of file diff --git a/custom_components/zhipuai/ai_request.py b/custom_components/zhipuai/ai_request.py new file mode 100644 index 0000000..9ac6910 --- /dev/null +++ b/custom_components/zhipuai/ai_request.py @@ -0,0 +1,68 @@ +import aiohttp +from aiohttp import TCPConnector +from homeassistant.exceptions import HomeAssistantError +from .const import LOGGER, ZHIPUAI_URL, CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT + +_SESSION = None + +async def get_session(): + global _SESSION + if _SESSION is None: + connector = TCPConnector(ssl=False) + _SESSION = aiohttp.ClientSession(connector=connector) + return _SESSION + +async def send_ai_request(api_key: str, payload: dict, options: dict = None) -> dict: + try: + session = await get_session() + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + timeout = aiohttp.ClientTimeout(total=options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)) + async with session.post(ZHIPUAI_URL, json=payload, headers=headers, timeout=timeout) as response: + if response.status == 401: + err_str = "API密钥无效或已过期" + elif response.status == 429: + err_str = "请求过于频繁,请稍后再试" + elif response.status in [500, 502, 503]: + err_str = "AI服务器暂时不可用,请稍后再试" + elif response.status == 400: + result = await response.json() + err_str = f"请求参数错误: {result.get('error', {}).get('message', '未知错误')}" + elif response.status != 200: + err_str = f"AI服务返回错误 {response.status}" + + if response.status != 200: + LOGGER.error("AI请求错误: %s", err_str) + raise HomeAssistantError(err_str) + + result = await response.json() + if "error" in result: + err_msg = result["error"].get("message", "未知错误") + LOGGER.error("AI返回错误: %s", err_msg) + if "token" in err_msg.lower(): + raise HomeAssistantError("生成的文本太长,请尝试缩短请求或减小max_tokens值") + elif "rate" in err_msg.lower(): + raise HomeAssistantError("请求过于频繁,请稍后再试") + else: + raise HomeAssistantError(f"AI服务返回错误: {err_msg}") + return result + + except Exception as err: + err_str = str(err).lower() if str(err) else "未知错误" + LOGGER.error("AI通信错误: %s", err_str) + + if not err_str or err_str.isspace(): + error_msg = "与AI服务通信失败,请检查网络连接和API密钥配置。" + else: + error_msg = "很抱歉,我现在无法正确处理您的请求。" + ( + "网络连接失败,请检查网络设置。" if any(x in err_str for x in ["通信", "communication", "connect", "socket"]) else + "请求超时,尝试减小max_tokens值或缩短请求。" if any(x in err_str for x in ["timeout", "connection", "network"]) else + "API密钥无效或已过期,请更新配置。" if any(x in err_str for x in ["api key", "token", "unauthorized", "authentication"]) else + "请求参数错误,请检查配置。" if "参数" in err_str or "parameter" in err_str else + f"发生错误: {err_str}" + ) + + LOGGER.error("与 AI 通信时出错: %s", err_str) + raise HomeAssistantError(error_msg) \ No newline at end of file diff --git a/custom_components/zhipuai/config_flow.py b/custom_components/zhipuai/config_flow.py new file mode 100644 index 0000000..4f500d4 --- /dev/null +++ b/custom_components/zhipuai/config_flow.py @@ -0,0 +1,414 @@ +from __future__ import annotations +from typing import Any +from types import MappingProxyType +import voluptuous as vol +import aiohttp +from homeassistant.core import HomeAssistant, callback +from homeassistant.config_entries import ( + ConfigEntry, + ConfigFlow, + ConfigFlowResult, + OptionsFlow, +) +from homeassistant import exceptions +from homeassistant.const import CONF_API_KEY, CONF_NAME, CONF_LLM_HASS_API +from homeassistant.helpers import llm +import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.selector import ( + NumberSelector, + NumberSelectorConfig, + SelectOptionDict, + SelectSelector, + SelectSelectorConfig, + TemplateSelector, +) + +from .const import ( + CONF_PROMPT, + CONF_TEMPERATURE, + DEFAULT_NAME, + CONF_CHAT_MODEL, + CONF_MAX_TOKENS, + CONF_RECOMMENDED, + CONF_TOP_P, + CONF_MAX_HISTORY_MESSAGES, + DOMAIN, + RECOMMENDED_MAX_TOKENS, + RECOMMENDED_TEMPERATURE, + RECOMMENDED_TOP_P, + RECOMMENDED_MAX_HISTORY_MESSAGES, + CONF_MAX_TOOL_ITERATIONS, + CONF_COOLDOWN_PERIOD, + DEFAULT_MAX_TOOL_ITERATIONS, + DEFAULT_COOLDOWN_PERIOD, + CONF_WEB_SEARCH, + DEFAULT_WEB_SEARCH, + CONF_HISTORY_ANALYSIS, + CONF_HISTORY_ENTITIES, + CONF_HISTORY_DAYS, + DEFAULT_HISTORY_ANALYSIS, + DEFAULT_HISTORY_DAYS, + MAX_HISTORY_DAYS, + CONF_HISTORY_INTERVAL, + DEFAULT_HISTORY_INTERVAL, + CONF_REQUEST_TIMEOUT, + DEFAULT_REQUEST_TIMEOUT, +) + +ZHIPUAI_MODELS = [ + "GLM-4-Plus", + "GLM-4V-Plus", + "GLM-4-0520", + "GLM-4-Long", + "GLM-4-AirX", + "GLM-4-Air", + "GLM-4-FlashX", + "GLM-4-Flash", + "GLM-4V", + "GLM-4-AllTools", + "GLM-4", +] + +RECOMMENDED_CHAT_MODEL = "GLM-4-Flash" + +RECOMMENDED_OPTIONS = { + CONF_RECOMMENDED: True, + CONF_LLM_HASS_API: llm.LLM_API_ASSIST, + CONF_PROMPT: """您是 Home Assistant 的语音助手。 +如实回答有关世界的问题。 +以纯文本形式回答。保持简单明了。""", + CONF_MAX_HISTORY_MESSAGES: RECOMMENDED_MAX_HISTORY_MESSAGES, + CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL, + CONF_MAX_TOOL_ITERATIONS: DEFAULT_MAX_TOOL_ITERATIONS, + CONF_COOLDOWN_PERIOD: DEFAULT_COOLDOWN_PERIOD, +} + +class ZhipuAIConfigFlow(ConfigFlow, domain=DOMAIN): + VERSION = 1 + MINOR_VERSION = 0 + + def __init__(self) -> None: + self._reauth_entry: ConfigEntry | None = None + self._reconfigure_entry: ConfigEntry | None = None + + async def async_step_user( + self, + user_input: dict[str, Any] | None = None, + ) -> ConfigFlowResult: + errors = {} + if user_input is not None: + try: + await self._validate_api_key(user_input[CONF_API_KEY]) + return self.async_create_entry( + title=user_input[CONF_NAME], + data=user_input, + options=RECOMMENDED_OPTIONS, + ) + except UnauthorizedError: + errors["base"] = "invalid_auth" + except aiohttp.ClientError: + errors["base"] = "cannot_connect" + except ModelNotFound: + errors["base"] = "model_not_found" + except Exception: + errors["base"] = "unknown" + + return self.async_show_form( + step_id="user", + data_schema=vol.Schema({ + vol.Required(CONF_NAME, default=DEFAULT_NAME): cv.string, + vol.Required(CONF_API_KEY): cv.string, + }), + errors=errors, + ) + + async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> ConfigFlowResult: + self._reauth_entry = self._get_reauth_entry() + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + errors = {} + + if user_input is not None: + try: + await self._validate_api_key(user_input[CONF_API_KEY]) + assert self._reauth_entry is not None + return self.async_update_reload_and_abort( + self._reauth_entry, + data_updates={CONF_API_KEY: user_input[CONF_API_KEY]}, + reason="reauth_successful", + ) + except UnauthorizedError: + errors["base"] = "invalid_auth" + except aiohttp.ClientError: + errors["base"] = "cannot_connect" + except Exception: + errors["base"] = "unknown" + + return self.async_show_form( + step_id="reauth_confirm", + data_schema=vol.Schema({ + vol.Required(CONF_API_KEY): cv.string, + }), + errors=errors, + ) + + async def async_step_reconfigure(self, entry_data: Mapping[str, Any]) -> ConfigFlowResult: + self._reconfigure_entry = self._get_reconfigure_entry() + return await self.async_step_reconfigure_confirm() + + async def async_step_reconfigure_confirm( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + errors = {} + + if user_input is not None: + try: + await self._validate_api_key(user_input[CONF_API_KEY]) + assert self._reconfigure_entry is not None + return self.async_update_reload_and_abort( + self._reconfigure_entry, + data_updates={CONF_API_KEY: user_input[CONF_API_KEY]}, + reason="reconfigure_successful", + ) + except UnauthorizedError: + errors["base"] = "invalid_auth" + except aiohttp.ClientError: + errors["base"] = "cannot_connect" + except Exception: + errors["base"] = "unknown" + + return self.async_show_form( + step_id="reconfigure_confirm", + data_schema=vol.Schema({ + vol.Required(CONF_API_KEY): cv.string, + }), + errors=errors, + ) + + async def _validate_api_key(self, api_key: str) -> None: + url = "https://open.bigmodel.cn/api/paas/v4/chat/completions" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}" + } + data = { + "model": RECOMMENDED_CHAT_MODEL, + "messages": [{"role": "user", "content": "你好"}] + } + + async with aiohttp.ClientSession() as session: + try: + async with session.post(url, headers=headers, json=data) as response: + if response.status == 200: + return + elif response.status == 401: + raise UnauthorizedError() + else: + response_json = await response.json() + error = response_json.get("error", {}) + error_message = error.get("message", "") + if "model not found" in error_message.lower(): + raise ModelNotFound() + else: + raise InvalidAPIKey() + except aiohttp.ClientError as e: + raise + + @staticmethod + @callback + def async_get_options_flow(config_entry: ConfigEntry) -> ZhipuAIOptionsFlow: + return ZhipuAIOptionsFlow(config_entry) + + +class ZhipuAIOptionsFlow(OptionsFlow): + def __init__(self, config_entry: ConfigEntry) -> None: + self._config_entry = config_entry + self._data = {} + + async def async_step_init( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + errors = {} + if user_input is not None: + try: + cooldown_period = user_input.get(CONF_COOLDOWN_PERIOD) + if cooldown_period is not None: + cooldown_period = float(cooldown_period) + if cooldown_period < 0: + errors[CONF_COOLDOWN_PERIOD] = "cooldown_too_small" + elif cooldown_period > 10: + errors[CONF_COOLDOWN_PERIOD] = "cooldown_too_large" + + if not errors: + self._data.update(user_input) + if user_input.get(CONF_HISTORY_ANALYSIS): + return await self.async_step_history() + return self.async_create_entry(title="", data=self._data) + except ValueError: + errors["base"] = "invalid_option" + + schema = vol.Schema(zhipuai_config_option_schema(self.hass, self._config_entry.options)) + return self.async_show_form( + step_id="init", + data_schema=schema, + errors=errors, + ) + + async def async_step_history( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + errors = {} + current_options = self._config_entry.options + + if user_input is not None: + try: + days = user_input.get(CONF_HISTORY_DAYS, DEFAULT_HISTORY_DAYS) + if days < 1 or days > MAX_HISTORY_DAYS: + errors[CONF_HISTORY_DAYS] = "invalid_days" + if not user_input.get(CONF_HISTORY_ENTITIES): + errors[CONF_HISTORY_ENTITIES] = "no_entities" + + if not errors: + self._data.update(user_input) + return self.async_create_entry(title="", data=self._data) + except ValueError: + errors["base"] = "invalid_option" + + entities = {} + for entity in self.hass.states.async_all(): + friendly_name = entity.attributes.get("friendly_name", entity.entity_id) + entities[entity.entity_id] = f"{friendly_name} ({entity.entity_id})" + + return self.async_show_form( + step_id="history", + data_schema=vol.Schema({ + vol.Required( + CONF_HISTORY_ENTITIES, + default=current_options.get(CONF_HISTORY_ENTITIES, []) + ): SelectSelector( + SelectSelectorConfig( + options=[{"value": k, "label": v} for k, v in entities.items()], + multiple=True, + mode="dropdown", + custom_value=False, + ) + ), + vol.Optional( + CONF_HISTORY_INTERVAL, + default=current_options.get(CONF_HISTORY_INTERVAL, DEFAULT_HISTORY_INTERVAL), + ): vol.Coerce(int), + vol.Required( + CONF_HISTORY_DAYS, + default=current_options.get(CONF_HISTORY_DAYS, DEFAULT_HISTORY_DAYS) + ): vol.All( + vol.Coerce(int), + vol.Range(min=1, max=MAX_HISTORY_DAYS), + ), + }), + errors=errors, + ) + + +def zhipuai_config_option_schema( + hass: HomeAssistant, + options: dict[str, Any] | MappingProxyType[str, Any], +) -> dict: + hass_apis = [SelectOptionDict(label="No", value="none")] + hass_apis.extend( + SelectOptionDict(label=api.name, value=api.id) + for api in llm.async_get_apis(hass) + ) + + schema = { + vol.Optional( + CONF_PROMPT, + description={"suggested_value": options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT)}, + ): TemplateSelector(), + vol.Optional( + CONF_LLM_HASS_API, + description={"suggested_value": options.get(CONF_LLM_HASS_API)}, + default="none", + ): SelectSelector(SelectSelectorConfig(options=hass_apis)), + vol.Required( + CONF_RECOMMENDED, + default=options.get(CONF_RECOMMENDED, False) + ): bool, + vol.Optional( + CONF_MAX_HISTORY_MESSAGES, + description={"suggested_value": options.get(CONF_MAX_HISTORY_MESSAGES)}, + default=RECOMMENDED_MAX_HISTORY_MESSAGES, + ): int, + vol.Optional( + CONF_CHAT_MODEL, + description={"suggested_value": options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)}, + default=RECOMMENDED_CHAT_MODEL, + ): SelectSelector(SelectSelectorConfig(options=ZHIPUAI_MODELS)), + vol.Optional( + CONF_MAX_TOOL_ITERATIONS, + description={"suggested_value": options.get(CONF_MAX_TOOL_ITERATIONS, DEFAULT_MAX_TOOL_ITERATIONS)}, + default=DEFAULT_MAX_TOOL_ITERATIONS, + ): int, + vol.Optional( + CONF_COOLDOWN_PERIOD, + description={"suggested_value": options.get(CONF_COOLDOWN_PERIOD, DEFAULT_COOLDOWN_PERIOD)}, + default=DEFAULT_COOLDOWN_PERIOD, + ): vol.All( + vol.Coerce(float), + vol.Range(min=0, max=10), + msg="冷却时间必须在0到10秒之间" + ), + vol.Optional( + CONF_REQUEST_TIMEOUT, + description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)}, + default=DEFAULT_REQUEST_TIMEOUT, + ): vol.All( + vol.Coerce(float), + vol.Range(min=10, max=120), + msg="超时时间必须在10到120秒之间" + ), + vol.Optional( + CONF_WEB_SEARCH, + default=DEFAULT_WEB_SEARCH, + ): bool, + vol.Optional( + CONF_HISTORY_ANALYSIS, + default=options.get(CONF_HISTORY_ANALYSIS, DEFAULT_HISTORY_ANALYSIS), + description={"suggested_value": options.get(CONF_HISTORY_ANALYSIS, DEFAULT_HISTORY_ANALYSIS)}, + ): bool, + } + + if not options.get(CONF_RECOMMENDED, False): + schema.update({ + vol.Optional( + CONF_MAX_TOKENS, + description={"suggested_value": options.get(CONF_MAX_TOKENS)}, + default=RECOMMENDED_MAX_TOKENS, + ): int, + vol.Optional( + CONF_TOP_P, + description={"suggested_value": options.get(CONF_TOP_P)}, + default=RECOMMENDED_TOP_P, + ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), + vol.Optional( + CONF_TEMPERATURE, + description={"suggested_value": options.get(CONF_TEMPERATURE)}, + default=RECOMMENDED_TEMPERATURE, + ): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)), + }) + + return schema + +class UnknownError(exceptions.HomeAssistantError): + pass + +class UnauthorizedError(exceptions.HomeAssistantError): + pass + +class InvalidAPIKey(exceptions.HomeAssistantError): + pass + +class ModelNotFound(exceptions.HomeAssistantError): + pass \ No newline at end of file diff --git a/custom_components/zhipuai/const.py b/custom_components/zhipuai/const.py new file mode 100644 index 0000000..3988afe --- /dev/null +++ b/custom_components/zhipuai/const.py @@ -0,0 +1,63 @@ +import logging +LOGGER = logging.getLogger(__name__) + +DOMAIN = "zhipuai" +NAME = "config.step.user.data.name" +DEFAULT_NAME = "智谱清言" +CONF_API_KEY = "api_key" +CONF_RECOMMENDED = "recommended" +CONF_PROMPT = "prompt" +CONF_CHAT_MODEL = "chat_model" +RECOMMENDED_CHAT_MODEL = "GLM-4-Flash" +CONF_MAX_TOKENS = "max_tokens" +RECOMMENDED_MAX_TOKENS = 2000 +CONF_TOP_P = "top_p" +RECOMMENDED_TOP_P = 0.7 +CONF_REQUEST_TIMEOUT = "request_timeout" +DEFAULT_REQUEST_TIMEOUT = 30 +CONF_TEMPERATURE = "temperature" +RECOMMENDED_TEMPERATURE = 0.4 +CONF_MAX_HISTORY_MESSAGES = "max_history_messages" +RECOMMENDED_MAX_HISTORY_MESSAGES = 5 + +CONF_MAX_TOOL_ITERATIONS = "max_tool_iterations" +DEFAULT_MAX_TOOL_ITERATIONS = 5 +CONF_COOLDOWN_PERIOD = "cooldown_period" +DEFAULT_COOLDOWN_PERIOD = 1 + +CONF_WEB_SEARCH = "web_search" +DEFAULT_WEB_SEARCH = True + +ZHIPUAI_URL = "https://open.bigmodel.cn/api/paas/v4/chat/completions" + +ZHIPUAI_WEB_SEARCH_URL = "https://open.bigmodel.cn/api/paas/v4/tools" +CONF_WEB_SEARCH_STREAM = "web_search_stream" +DEFAULT_WEB_SEARCH_STREAM = False + +ZHIPUAI_IMAGE_GEN_URL = "https://open.bigmodel.cn/api/paas/v4/images/generations" +CONF_IMAGE_GEN = "image_gen" +DEFAULT_IMAGE_GEN = False + +CONF_IMAGE_SIZE = "image_size" +DEFAULT_IMAGE_SIZE = "1024x1024" + +IMAGE_SIZES = [ + "1024x1024", + "768x1344", + "864x1152", + "1344x768", + "1152x864", + "1440x720", + "720x1440" +] + + +CONF_HISTORY_ANALYSIS = "history_analysis" +CONF_HISTORY_ENTITIES = "history_entities" +CONF_HISTORY_DAYS = "history_days" +CONF_HISTORY_INTERVAL = "history_interval" +DEFAULT_HISTORY_INTERVAL = 10 +DEFAULT_HISTORY_ANALYSIS = False +DEFAULT_HISTORY_DAYS = 1 +MAX_HISTORY_DAYS = 15 + diff --git a/custom_components/zhipuai/conversation.py b/custom_components/zhipuai/conversation.py new file mode 100644 index 0000000..73255d5 --- /dev/null +++ b/custom_components/zhipuai/conversation.py @@ -0,0 +1,629 @@ +from __future__ import annotations +import json +import asyncio +import time +import re +from datetime import datetime, timedelta +from typing import Any, Literal, TypedDict, Dict, List, Optional +from voluptuous_openapi import convert +from homeassistant.components import assist_pipeline, conversation +from homeassistant.components.conversation import trace +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, MATCH_ALL, ATTR_ENTITY_ID +from homeassistant.core import HomeAssistant, Context +from homeassistant.helpers import entity_registry as er +from homeassistant.helpers import device_registry as dr, intent, llm, template, entity_registry +from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.exceptions import HomeAssistantError +from homeassistant.components.recorder import get_instance +from homeassistant.components.recorder.history import get_significant_states +from homeassistant.util import ulid +from home_assistant_intents import get_languages +from .ai_request import send_ai_request +from .intents import IntentHandler, extract_intent_info +from .const import ( + CONF_CHAT_MODEL, + CONF_MAX_TOKENS, + CONF_PROMPT, + CONF_TEMPERATURE, + CONF_TOP_P, + CONF_MAX_HISTORY_MESSAGES, + DOMAIN, + RECOMMENDED_CHAT_MODEL, + RECOMMENDED_MAX_TOKENS, + RECOMMENDED_TEMPERATURE, + RECOMMENDED_MAX_HISTORY_MESSAGES, + RECOMMENDED_TOP_P, + CONF_MAX_TOOL_ITERATIONS, + CONF_COOLDOWN_PERIOD, + DEFAULT_MAX_TOOL_ITERATIONS, + DEFAULT_COOLDOWN_PERIOD, + LOGGER, + CONF_HISTORY_ANALYSIS, + CONF_HISTORY_ENTITIES, + CONF_HISTORY_DAYS, + DEFAULT_HISTORY_DAYS, + CONF_WEB_SEARCH, + DEFAULT_WEB_SEARCH, + CONF_HISTORY_INTERVAL, + DEFAULT_HISTORY_INTERVAL, +) + + +class ChatCompletionMessageParam(TypedDict, total=False): + role: str + content: str | None + name: str | None + tool_calls: list["ChatCompletionMessageToolCallParam"] | None + +class Function(TypedDict, total=False): + name: str + arguments: str + +class ChatCompletionMessageToolCallParam(TypedDict): + id: str + type: str + function: Function + +class ChatCompletionToolParam(TypedDict): + type: str + function: dict[str, Any] + +_FILTER_PATTERNS = [ + re.compile(r'```[\s\S]*?```'), + re.compile(r'{[\s\S]*?}'), + re.compile(r'(?m)^(import|from|def|class)\s+.*$') +] + +def _format_tool(tool: llm.Tool, custom_serializer: Any | None) -> ChatCompletionToolParam: + tool_spec = { + "name": tool.name, + "parameters": convert(tool.parameters, custom_serializer=custom_serializer), + } + if tool.description: + tool_spec["description"] = tool.description + return ChatCompletionToolParam(type="function", function=tool_spec) + +def is_service_call(user_input: str) -> bool: + patterns = { + "control": ["让", "请", "帮我", "麻烦", "把", "将", "计时", "要", "想", "希望", "需要", "能否", "能不能", "可不可以", "可以", "帮忙", "给我", "替我", "为我", "我要", "我想", "我希望"], + "action": { + "turn_on": ["打开", "开启", "启动", "激活", "运行", "执行"], + "turn_off": ["关闭", "关掉", "停止"], + "toggle": ["切换"], + "press": ["按", "按下", "点击"], + "select": ["选择", "下一个", "上一个", "第一个", "最后一个"], + "trigger": ["触发", "调用"], + "media": ["暂停", "继续播放", "播放", "停止", "下一首", "下一曲", "下一个", "切歌", "换歌","上一首", "上一曲", "上一个", "返回上一首", "音量"], + "climate": ["制冷", "制热", "风速", "模式", "调温", "调到", "设置", + "空调", "冷气", "暖气", "冷风", "暖风", "自动模式", "除湿", "送风", + "高档", "低档", "高速", "低速", "自动高", "自动低", "强劲", "自动"] + } + } + + has_pattern = bool(user_input and ( + any(k in user_input for k in patterns["control"]) or + any(k in user_input for action in patterns["action"].values() for k in (action if isinstance(action, list) else [])) + )) + + has_climate = ( + "空调" in user_input and ( + bool(re.search(r"\d+", user_input)) or + any(k in user_input for k in patterns["action"]["climate"]) + ) + ) + + has_entity = bool(re.search(r"([\w_]+\.[\w_]+)", user_input)) + + return has_pattern or has_climate or has_entity + +def extract_service_info(user_input: str, hass: HomeAssistant) -> Optional[Dict[str, Any]]: + def find_entity(domain: str, text: str) -> Optional[str]: + text = text.lower() + return next((entity_id for entity_id in hass.states.async_entity_ids(domain) + if text in entity_id.split(".")[1].lower() or + text in hass.states.get(entity_id).attributes.get("friendly_name", "").lower() or + entity_id.split(".")[1].lower() in text or + hass.states.get(entity_id).attributes.get("friendly_name", "").lower() in text), None) + + def clean_text(text: str, patterns: List[str]) -> str: + control_words = ["让", "请", "帮我", "麻烦", "把", "将"] + return "".join(char for char in text if not any(word in char for word in patterns + control_words)).strip() + + if not is_service_call(user_input): + return None + + media_patterns = {"暂停": "media_pause", "继续播放": "media_play", "播放": "media_play", "停止": "media_stop", + "下一首": "media_next_track", "下一曲": "media_next_track", "下一个": "media_next_track", + "切歌": "media_next_track", "换歌": "media_next_track", "上一首": "media_previous_track", + "上一曲": "media_previous_track", "上一个": "media_previous_track", + "返回上一首": "media_previous_track", "音量": "volume_set"} + + if entity_id := find_entity("media_player", user_input): + for pattern, service in media_patterns.items(): + if pattern in user_input.lower(): + return ({"domain": "media_player", "service": service, "data": {"entity_id": entity_id, "volume_level": int(re.search(r'(\d+)', user_input).group(1)) / 100}} + if service == "volume_set" and re.search(r'(\d+)', user_input) else + {"domain": "media_player", "service": service, "data": {"entity_id": entity_id}}) + + if any(p in user_input for p in ["按", "按下", "点击"]): + return {"domain": "button", "service": "press", "data": {"entity_id": (re.search(r'(button\.\w+)', user_input).group(1) if re.search(r'(button\.\w+)', user_input) else + find_entity("button", clean_text(user_input, ["按", "按下", "点击"])))}} if (re.search(r'(button\.\w+)', user_input) or + find_entity("button", clean_text(user_input, ["按", "按下", "点击"]))) else None + + select_patterns = {"下一个": ("select_next", True), "上一个": ("select_previous", True), + "第一个": ("select_first", False), "最后一个": ("select_last", False), + "选择": ("select_option", False)} + + if entity_id := find_entity("select", user_input): + return {"domain": "select", "service": select_patterns.get(next((k for k in select_patterns.keys() if k in user_input), "选择"))[0], + "data": {"entity_id": entity_id, "cycle": select_patterns.get(next((k for k in select_patterns.keys() if k in user_input), "选择"))[1]}} if any(p in user_input for p in select_patterns.keys()) else None + + if any(p in user_input for p in ["触发", "调用", "执行", "运行", "启动"]): + name = clean_text(user_input, ["触发", "调用", "执行", "运行", "启动", "脚本", "自动化", "场景"]) + return next(({"domain": domain, "service": service, "data": {"entity_id": entity_id}} + for domain, service in [("script", "turn_on"), ("automation", "trigger"), ("scene", "turn_on")] + if (entity_id := find_entity(domain, name))), None) + + climate_patterns = {"温度": "set_temperature", "制冷": "set_hvac_mode", "制热": "set_hvac_mode", + "风速": "set_fan_mode", "模式": "set_hvac_mode", "湿度": "set_humidity", + "调温": "set_temperature", "调到": "set_temperature", "设置": "set_hvac_mode", + "度": "set_temperature"} + + if entity_id := find_entity("climate", user_input): + temperature_match = re.search(r'(\d+)(?:\s*度)?', user_input) + if temperature_match and ("温度" in user_input or "调温" in user_input or "调到" in user_input or "度" in user_input): + temperature = int(temperature_match.group(1)) + current_temp = hass.states.get(entity_id).attributes.get('current_temperature') + hvac_mode = "cool" if current_temp > temperature else "heat" if current_temp is not None else None + return {"domain": "climate", "service": "set_temperature", + "data": {"entity_id": entity_id, "temperature": temperature, "hvac_mode": hvac_mode}} if hvac_mode else {"domain": "climate", "service": "set_temperature", + "data": {"entity_id": entity_id, "temperature": temperature}} + + for pattern, service in climate_patterns.items(): + if pattern in user_input.lower(): + if service == "set_temperature": + temperature_match = re.search(r'(\d+)', user_input) + if temperature_match: + temperature = int(temperature_match.group(1)) + current_temp = hass.states.get(entity_id).attributes.get('current_temperature') + hvac_mode = "cool" if current_temp > temperature else "heat" if current_temp is not None else None + return {"domain": "climate", "service": service, + "data": {"entity_id": entity_id, "temperature": temperature, "hvac_mode": hvac_mode}} if hvac_mode else {"domain": "climate", "service": service, + "data": {"entity_id": entity_id, "temperature": temperature}} + elif service == "set_hvac_mode": + hvac_mode_map = {"制冷": "cool", "制热": "heat", "自动": "auto", "除湿": "dry", + "送风": "fan_only", "关闭": "off", "停止": "off"} + return {"domain": "climate", "service": service, + "data": {"entity_id": entity_id, "hvac_mode": next((mode for cn, mode in hvac_mode_map.items() if cn in user_input), "auto")}} + elif service == "set_fan_mode": + fan_mode_map = {"高档": "on_high", "高速": "on_high", "强劲": "on_high", + "低档": "on_low", "低速": "on_low", "自动高": "auto_high", + "自动高档": "auto_high", "自动低": "auto_low", "自动低档": "auto_low", + "关闭": "off", "停止": "off"} + return {"domain": "climate", "service": service, + "data": {"entity_id": entity_id, "fan_mode": next((mode for cn, mode in fan_mode_map.items() if cn in user_input), "auto_low")}} + elif service == "set_humidity": + humidity_match = re.search(r'(\d+)', user_input) + return {"domain": "climate", "service": service, + "data": {"entity_id": entity_id, "humidity": int(humidity_match.group(1))}} if humidity_match else None + + return None + +class ZhipuAIConversationEntity(conversation.ConversationEntity, conversation.AbstractConversationAgent): + _attr_has_entity_name = True + _attr_name = None + _attr_response = "" + + def __init__(self, entry: ConfigEntry, hass: HomeAssistant) -> None: + self.entry = entry + self.hass = hass + self.history: dict[str, list[ChatCompletionMessageParam]] = {} + self._attr_unique_id = entry.entry_id + self._attr_device_info = dr.DeviceInfo( + identifiers={(DOMAIN, entry.entry_id)}, + name=entry.title, + manufacturer="智谱清言", + model="ChatGLM Pro", + entry_type=dr.DeviceEntryType.SERVICE, + ) + if self.entry.options.get(CONF_LLM_HASS_API) and self.entry.options.get(CONF_LLM_HASS_API) != "none": + self._attr_supported_features = conversation.ConversationEntityFeature.CONTROL + self.last_request_time = 0 + self.max_tool_iterations = min(entry.options.get(CONF_MAX_TOOL_ITERATIONS, DEFAULT_MAX_TOOL_ITERATIONS), 5) + self.cooldown_period = entry.options.get(CONF_COOLDOWN_PERIOD, DEFAULT_COOLDOWN_PERIOD) + self.llm_api = None + self.intent_handler = IntentHandler(hass) + self.entity_registry = er.async_get(hass) + self.device_registry = dr.async_get(hass) + self.service_call_attempts = 0 + self._attr_native_value = "就绪" + self._attr_extra_state_attributes = {"response": ""} + + @property + def supported_languages(self) -> list[str]: + return list(dict.fromkeys(languages + ["zh-cn", "zh-tw", "zh-hk", "en"])) if (languages := get_languages()) and "zh" in languages else languages + + @property + def state_attributes(self): + attributes = super().state_attributes or {} + attributes["entity"] = "ZHIPU.AI" + if self._attr_response: + attributes["response"] = self._attr_response + return attributes + + def _filter_response_content(self, content: str) -> str: + for pattern in _FILTER_PATTERNS: + content = pattern.sub('', content) + if not content.strip(): + return "抱歉,暂不支持该操作。如果问题持续,可能需要调整指令。" + return content.strip() + + async def async_added_to_hass(self) -> None: + await super().async_added_to_hass() + assist_pipeline.async_migrate_engine(self.hass, "conversation", self.entry.entry_id, self.entity_id) + conversation.async_set_agent(self.hass, self.entry, self) + self.entry.async_on_unload(self.entry.add_update_listener(self._async_entry_update_listener)) + + async def async_will_remove_from_hass(self) -> None: + conversation.async_unset_agent(self.hass, self.entry) + await super().async_will_remove_from_hass() + + async def async_process(self, user_input: conversation.ConversationInput) -> conversation.ConversationResult: + if user_input.context and user_input.context.id and user_input.context.id.startswith(f"{DOMAIN}_service_call"): + return None + + current_time = time.time() + if current_time - self.last_request_time < self.cooldown_period: + await asyncio.sleep(self.cooldown_period - (current_time - self.last_request_time)) + self.last_request_time = time.time() + + intent_response = intent.IntentResponse(language=user_input.language) + + if is_service_call(user_input.text): + service_info = extract_service_info(user_input.text, self.hass) + if service_info: + result = await self.intent_handler.call_service( + service_info["domain"], + service_info["service"], + service_info["data"] + ) + if result["success"]: + intent_response.async_set_speech(result["message"]) + else: + intent_response.async_set_error( + intent.IntentResponseErrorCode.NO_VALID_TARGETS, + result["message"] + ) + return conversation.ConversationResult( + response=intent_response, + conversation_id=user_input.conversation_id + ) + + options = self.entry.options + tools: list[ChatCompletionToolParam] | None = None + user_name: str | None = None + llm_context = llm.LLMContext( + platform=DOMAIN, + context=user_input.context, + user_prompt=user_input.text, + language=user_input.language, + assistant=conversation.DOMAIN, + device_id=user_input.device_id, + ) + + try: + if options.get(CONF_LLM_HASS_API) and options[CONF_LLM_HASS_API] != "none": + self.llm_api = await llm.async_get_api(self.hass, options[CONF_LLM_HASS_API], llm_context) + tools = [_format_tool(tool, self.llm_api.custom_serializer) for tool in self.llm_api.tools][:8] + + if not options.get(CONF_WEB_SEARCH, DEFAULT_WEB_SEARCH): + if any(term in user_input.text.lower() for term in ["联网", "查询", "网页", "search"]): + intent_response.async_set_speech("联网搜索功能已关闭,请在配置中开启后再试。") + return conversation.ConversationResult( + response=intent_response, + conversation_id=user_input.conversation_id + ) + tools = [tool for tool in tools if tool["function"]["name"] != "web_search"] + except HomeAssistantError as err: + intent_response.async_set_error(intent.IntentResponseErrorCode.UNKNOWN, f"获取 LLM API 时出错,将继续使用基本功能。") + + intent_info = extract_intent_info(user_input.text, self.hass) + if intent_info: + result = await self.intent_handler.handle_intent(intent_info) + if result["success"]: + intent_response.async_set_speech(result["message"]) + else: + intent_response.async_set_error( + intent.IntentResponseErrorCode.NO_VALID_TARGETS, + result["message"] + ) + return conversation.ConversationResult( + response=intent_response, + conversation_id=user_input.conversation_id + ) + + if user_input.conversation_id is None: + conversation_id = ulid.ulid_now() + messages = [] + elif user_input.conversation_id in self.history: + conversation_id = user_input.conversation_id + messages = self.history[conversation_id] + else: + conversation_id = user_input.conversation_id + messages = [] + + max_history_messages = options.get(CONF_MAX_HISTORY_MESSAGES, RECOMMENDED_MAX_HISTORY_MESSAGES) + use_history = len(messages) < max_history_messages + + if user_input.context and user_input.context.user_id and (user := await self.hass.auth.async_get_user(user_input.context.user_id)): + user_name = user.name + try: + er = entity_registry.async_get(self.hass) + entities_dict = {entity_id: er.async_get(entity_id) + for entity_id in self.hass.states.async_entity_ids()} + exposed_entities = [ + entity for entity_id, entity in entities_dict.items() + if entity and not entity.hidden + ] + + prompt_parts = [ + template.Template( + llm.BASE_PROMPT + options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT), + self.hass, + ).async_render( + { + "ha_name": self.hass.config.location_name, + "user_name": user_name, + "llm_context": llm_context, + "exposed_entities": exposed_entities, + }, + parse_result=False, + ) + ] + + if self.entry.options.get(CONF_HISTORY_ANALYSIS): + entities = self.entry.options.get(CONF_HISTORY_ENTITIES, []) + days = self.entry.options.get(CONF_HISTORY_DAYS, DEFAULT_HISTORY_DAYS) + + if entities: + try: + end_time = datetime.now() + start_time = end_time - timedelta(days=days) + + history_text = [] + history_text.append(f"\n以下是询问者所关注的实体的历史数据分析({days}天内):") + + instance = get_instance(self.hass) + history_data = await instance.async_add_executor_job( + get_significant_states, + self.hass, + start_time, + end_time, + entities, + None, + True, + True + ) + + for entity_id in entities: + state = self.hass.states.get(entity_id) + if state is None or entity_id not in history_data or not history_data[entity_id]: + history_text.append(f"\n{entity_id} (当前状态):") + history_text.append( + f"- {state.state if state else 'unknown'} ({state.last_updated.astimezone().strftime('%m-%d %H:%M:%S') if state else 'unknown'})" + ) + else: + states = history_data[entity_id] + history_text.append(f"\n{entity_id} (历史状态变化):") + last_state_text = None + last_time = None + for state in states: + if state.state == "unavailable": + continue + + current_time = state.last_updated.astimezone() + interval_minutes = self.entry.options.get(CONF_HISTORY_INTERVAL, DEFAULT_HISTORY_INTERVAL) + if last_time and (current_time - last_time).total_seconds() < interval_minutes * 60: + continue + + state_text = f"- {state.state} ({current_time.strftime('%m-%d %H:%M:%S')})" + if state_text != last_state_text: + history_text.append(state_text) + last_state_text = state_text + last_time = current_time + if len(history_text) > 1: + prompt_parts.append("\n".join(history_text)) + + except Exception as err: + LOGGER.warning(f"获取历史数据时出错: {err}") + + except template.TemplateError as err: + content_message = f"抱歉,Jinja2 模板解析出错,请检查配置模板,有实体信息配置导致获取失败: {err}" + filtered_content = self._filter_response_content(content_message) + intent_response.async_set_error(intent.IntentResponseErrorCode.UNKNOWN, filtered_content) + return conversation.ConversationResult(response=intent_response, conversation_id=conversation_id) + + if self.llm_api: + prompt_parts.append(self.llm_api.api_prompt) + + prompt = "\n".join(prompt_parts) + LOGGER.info("提示部件: %s", prompt_parts) + + messages = [ + ChatCompletionMessageParam(role="system", content=prompt), + *(messages if use_history else []), + ChatCompletionMessageParam(role="user", content=user_input.text), + ] + if len(messages) > max_history_messages + 1: + messages = [messages[0]] + messages[-(max_history_messages):] + + api_key = self.entry.data[CONF_API_KEY] + try: + for iteration in range(self.max_tool_iterations): + payload = { + "model": options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), + "messages": messages, + "max_tokens": min(options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), 4096), + "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P), + "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), + "request_id": conversation_id, + } + if tools: + payload["tools"] = tools + + + result = await send_ai_request(api_key, payload, options) + response = result["choices"][0]["message"] + + messages.append(response) + self._attr_response = response.get("content", "") + + tool_calls = response.get("tool_calls") + + if not tool_calls: + break + + tool_call_failed = False + for tool_call in tool_calls: + try: + tool_input = llm.ToolInput( + tool_name=tool_call["function"]["name"], + tool_args=json.loads(tool_call["function"]["arguments"]), + ) + + tool_response = await self._handle_tool_call(tool_input, user_input.text) + + if isinstance(tool_response, dict) and "error" in tool_response: + raise Exception(tool_response["error"]) + + formatted_response = json.dumps(tool_response) + messages.append( + ChatCompletionMessageParam( + role="tool", + tool_call_id=tool_call["id"], + content=formatted_response, + ) + ) + except Exception as e: + content_message = f"操作执行失败: {str(e)}" + messages.append( + ChatCompletionMessageParam( + role="tool", + tool_call_id=tool_call["id"], + content=content_message, + ) + ) + tool_call_failed = True + + if tool_call_failed and self.service_call_attempts >= 3: + return await self._fallback_to_hass_llm(user_input, conversation_id) + + + final_content = response.get("content", "").strip() + + if is_service_call(user_input.text): + service_info = extract_service_info(final_content, self.hass) + if service_info: + try: + await self.hass.services.async_call( + service_info["domain"], + service_info["service"], + service_info["data"], + blocking=True + ) + message = f"成功执行{service_info['domain']}:{service_info['data']['entity_id']}" + intent_response.response_text = message + return conversation.ConversationResult( + response=intent_response, + conversation_id=conversation_id + ) + except Exception as e: + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"执行服务失败:{str(e)}" + ) + return conversation.ConversationResult( + response=intent_response, + conversation_id=conversation_id + ) + + filtered_content = self._filter_response_content(final_content) + + self.history[conversation_id] = messages + intent_response.async_set_speech(filtered_content) + self._attr_extra_state_attributes["response"] = filtered_content + return conversation.ConversationResult(response=intent_response, conversation_id=conversation_id) + + except Exception as err: + return await self._fallback_to_hass_llm(user_input, conversation_id) + + async def _handle_tool_call(self, tool_input: llm.ToolInput, user_input: str) -> Dict[str, Any]: + try: + if self.llm_api and hasattr(self.llm_api, "async_call_tool"): + try: + result = await self.llm_api.async_call_tool(tool_input) + if isinstance(result, dict): + if "error" not in result: + return result + else: + LOGGER.warning("LLM API调用返回错误: %s", result["error"]) + return {"error": str(result["error"])} + except AttributeError as e: + LOGGER.warning("LLM API调用参数错误: %s", str(e)) + return {"error": "工具调用参数错误"} + except Exception as e: + LOGGER.warning("LLM API调用失败: %s", str(e)) + return {"error": f"工具调用失败: {str(e)}"} + + if is_service_call(user_input): + service_info = extract_service_info(user_input, self.hass) + if service_info: + result = await self.intent_handler.call_service( + service_info["domain"], + service_info["service"], + service_info["data"] + ) + return result + else: + return {"error": "无法解析服务调用信息"} + + return {"error": "无法处理该工具调用"} + + except Exception as e: + return {"error": f"处理工具调用时发生错误: {str(e)}"} + + async def _fallback_to_hass_llm(self, user_input: conversation.ConversationInput, conversation_id: str) -> conversation.ConversationResult: + try: + agent = await conversation.async_get_agent(self.hass) + result = await agent.async_process(user_input) + return result + except Exception as err: + error_msg = "很抱歉,我现在无法正确处理您的请求,请稍后再试" + + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + error_msg + ) + return conversation.ConversationResult(response=intent_response, conversation_id=conversation_id) + + @staticmethod + async def _async_entry_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None: + entity = hass.data[DOMAIN].get(entry.entry_id) + if entity: + entity.entry = entry + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + entity = ZhipuAIConversationEntity(config_entry, hass) + async_add_entities([entity]) + hass.data.setdefault(DOMAIN, {})[config_entry.entry_id] = entity + +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + if unload_ok := await hass.config_entries.async_unload_platforms(entry, ["conversation"]): + hass.data[DOMAIN].pop(entry.entry_id, None) + return unload_ok \ No newline at end of file diff --git a/custom_components/zhipuai/entity_analysis.py b/custom_components/zhipuai/entity_analysis.py new file mode 100644 index 0000000..8b25539 --- /dev/null +++ b/custom_components/zhipuai/entity_analysis.py @@ -0,0 +1,118 @@ + +from __future__ import annotations + +import json +import voluptuous as vol +from datetime import datetime, timedelta +from typing import List + +from homeassistant.core import HomeAssistant, ServiceCall, State +from homeassistant.exceptions import ServiceValidationError +from homeassistant.helpers import entity_registry as er, config_validation as cv +from homeassistant.components.recorder import get_instance +from homeassistant.components.recorder.history import get_significant_states + +from .const import DOMAIN, LOGGER + +ENTITY_ANALYSIS_SCHEMA = vol.Schema({ + vol.Required("entity_id"): vol.Any(cv.entity_id, [cv.entity_id]), + vol.Optional("days", default=3): vol.All( + vol.Coerce(int), + vol.Range(min=1, max=15) + ) +}) + +async def async_setup_entity_analysis(hass: HomeAssistant) -> None: + + + async def handle_entity_analysis(call: ServiceCall) -> dict: + + try: + entity_ids = call.data["entity_id"] + if isinstance(entity_ids, str): + entity_ids = [entity_ids] + + days = call.data.get("days", 3) + + entity_registry = er.async_get(hass) + valid_entity_ids = [] + current_states = {} + + for entity_id in entity_ids: + state = hass.states.get(entity_id) + if state is None: + LOGGER.warning("实体 %s 不存在", entity_id) + continue + + if not entity_registry.async_get(entity_id): + LOGGER.info("实体 %s 未注册但存在,将只获取当前状态", entity_id) + current_states[entity_id] = state + else: + valid_entity_ids.append(entity_id) + + if not valid_entity_ids and not current_states: + error_msg = "没有找到任何有效的实体" + LOGGER.error(error_msg) + return {"success": False, "message": error_msg} + + end_time = datetime.now() + start_time = end_time - timedelta(days=days) + history_text = [] + + if valid_entity_ids: + instance = get_instance(hass) + history_data = await instance.async_add_executor_job( + get_significant_states, + hass, + start_time, + end_time, + valid_entity_ids, + None, + True, + True + ) + + if history_data: + for entity_id in valid_entity_ids: + if entity_id not in history_data: + continue + + for state in history_data[entity_id]: + if state is None: + continue + history_text.append( + f"{entity_id}, {state.state}, {state.last_updated.strftime('%Y-%m-%d %H:%M:%S')}" + ) + + for entity_id, state in current_states.items(): + history_text.append( + f"{entity_id}, {state.state}, {state.last_updated.strftime('%Y-%m-%d %H:%M:%S')}" + ) + + message = ( + f"分析时间范围: {start_time.strftime('%Y-%m-%d %H:%M:%S')} 至 " + f"{end_time.strftime('%Y-%m-%d %H:%M:%S')}\n" + f"总计 {len(history_text)} 条记录\n\n" + "实体ID, 状态值, 更新时间\n" + f"{chr(10).join(history_text)}" + ) + + LOGGER.info("生成的历史记录文本:\n%s", message) + + return { + "success": True, + "message": message + } + + except Exception as e: + error_msg = f"获取历史记录失败: {str(e)}" + LOGGER.error(error_msg) + return {"success": False, "message": error_msg} + + hass.services.async_register( + DOMAIN, + "entity_analysis", + handle_entity_analysis, + schema=ENTITY_ANALYSIS_SCHEMA, + supports_response=True + ) \ No newline at end of file diff --git a/custom_components/zhipuai/image_gen.py b/custom_components/zhipuai/image_gen.py new file mode 100644 index 0000000..a19cec3 --- /dev/null +++ b/custom_components/zhipuai/image_gen.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import json +import os +import uuid +import aiohttp +import voluptuous as vol +from homeassistant.core import HomeAssistant, ServiceCall +from homeassistant.exceptions import ServiceValidationError + +from .const import ( + DOMAIN, + LOGGER, + ZHIPUAI_IMAGE_GEN_URL, + IMAGE_SIZES, + DEFAULT_IMAGE_SIZE, +) + +IMAGE_GEN_SCHEMA = vol.Schema({ + vol.Required("prompt"): str, + vol.Optional("model", default="cogview-3-flash"): vol.In(["cogview-3-plus", "cogview-3", "cogview-3-flash"]), + vol.Optional("size", default=DEFAULT_IMAGE_SIZE): vol.In(IMAGE_SIZES), +}) + +async def async_setup_image_gen(hass: HomeAssistant) -> None: + async def handle_image_gen(call: ServiceCall) -> None: + try: + config_entries = hass.config_entries.async_entries(DOMAIN) + if not config_entries: + raise ValueError("未找到 ZhipuAI 配置") + api_key = config_entries[0].data.get("api_key") + if not api_key: + raise ValueError("在配置中找不到 API 密钥") + + prompt = call.data["prompt"] + model = call.data.get("model", "cogview-3-flash") + size = call.data.get("size", DEFAULT_IMAGE_SIZE) + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + payload = { + "model": model, + "prompt": prompt, + "size": size + } + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + ZHIPUAI_IMAGE_GEN_URL, + headers=headers, + json=payload, + timeout=300 + ) as response: + response.raise_for_status() + result = await response.json() + + if not result.get("data") or not result["data"][0].get("url"): + raise ValueError("API 未返回有效的图片 URL") + + image_url = result["data"][0]["url"] + + # 下载图片并保存到本地 + www_dir = hass.config.path("www") + if not os.path.exists(www_dir): + os.makedirs(www_dir) + + filename = f"zhipuai_image_{uuid.uuid4().hex[:8]}.jpg" + filepath = os.path.join(www_dir, filename) + + async with aiohttp.ClientSession() as session: + async with session.get(image_url) as img_response: + img_response.raise_for_status() + with open(filepath, "wb") as f: + f.write(await img_response.read()) + + local_url = f"/local/{filename}" + + hass.bus.async_fire(f"{DOMAIN}_response", { + "type": "image_gen", + "content": local_url, + "success": True + }) + + return { + "success": True, + "message": local_url, + "original_url": image_url + } + + except aiohttp.ClientError as e: + LOGGER.error(f"API请求失败: {str(e)}") + raise ServiceValidationError(f"API请求失败: {str(e)}") + + except Exception as e: + error_msg = f"图像生成失败: {str(e)}" + LOGGER.error(error_msg) + return {"success": False, "message": error_msg} + + hass.services.async_register( + DOMAIN, + "image_gen", + handle_image_gen, + schema=IMAGE_GEN_SCHEMA, + supports_response=True + ) diff --git a/custom_components/zhipuai/intents.py b/custom_components/zhipuai/intents.py new file mode 100644 index 0000000..2ca1b2b --- /dev/null +++ b/custom_components/zhipuai/intents.py @@ -0,0 +1,543 @@ +from __future__ import annotations +import re +import os +import yaml +from datetime import timedelta, datetime +from typing import Any, Dict, List, Optional, Set +import voluptuous as vol +from homeassistant.components import camera +from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN +from homeassistant.components.lock import LockState +from homeassistant.components.timer import ( + ATTR_DURATION, ATTR_REMAINING, + CONF_DURATION, CONF_ICON, + DOMAIN as TIMER_DOMAIN, + SERVICE_START, SERVICE_PAUSE, SERVICE_CANCEL +) +from homeassistant.const import ATTR_DEVICE_CLASS, ATTR_ENTITY_ID +from homeassistant.core import Context, HomeAssistant, ServiceResponse, State +from homeassistant.helpers import area_registry, device_registry, entity_registry, intent +from .const import DOMAIN, LOGGER + +_YAML_CACHE = {} + +async def async_load_yaml_config(hass: HomeAssistant, path: str) -> dict: + if path not in _YAML_CACHE: + if os.path.exists(path): + def _load_yaml(): + with open(path, 'r', encoding='utf-8') as f: + return yaml.safe_load(f) + _YAML_CACHE[path] = await hass.async_add_executor_job(_load_yaml) + return _YAML_CACHE.get(path, {}) + +INTENT_CAMERA_ANALYZE = "ZhipuAICameraAnalyze" +INTENT_WEB_SEARCH = "ZhipuAIWebSearch" +INTENT_TIMER = "HassTimerIntent" +INTENT_NOTIFY = "HassNotifyIntent" +INTENT_COVER_GET_STATE = "ZHIPUAI_CoverGetStateIntent" +INTENT_COVER_SET_POSITION = "ZHIPUAI_CoverSetPositionIntent" +INTENT_NEVERMIND = "nevermind" +SERVICE_PROCESS = "process" +ERROR_NO_CAMERA = "no_camera" +ERROR_NO_RESPONSE = "no_response" +ERROR_SERVICE_CALL = "service_call_error" +ERROR_NO_QUERY = "no_query" +ERROR_NO_TIMER = "no_timer" +ERROR_NO_MESSAGE = "no_message" +ERROR_INVALID_POSITION = "invalid_position" + +async def async_setup_intents(hass: HomeAssistant) -> None: + yaml_path = os.path.join(os.path.dirname(__file__), "intents.yaml") + intents_config = await async_load_yaml_config(hass, yaml_path) + if intents_config: + LOGGER.info("从 %s 加载的 intent 配置", yaml_path) + + intent.async_register(hass, CameraAnalyzeIntent(hass)) + intent.async_register(hass, WebSearchIntent(hass)) + intent.async_register(hass, HassTimerIntent(hass)) + intent.async_register(hass, HassNotifyIntent(hass)) + + +class CameraAnalyzeIntent(intent.IntentHandler): + intent_type = INTENT_CAMERA_ANALYZE + slot_schema = {vol.Required("camera_name"): str, vol.Required("question"): str} + + def __init__(self, hass: HomeAssistant): + super().__init__() + self.hass = hass + self.config = {} + self._config_loaded = False + + async def _load_config(self): + if not self._config_loaded: + yaml_path = os.path.join(os.path.dirname(__file__), "intents.yaml") + config = await async_load_yaml_config(self.hass, yaml_path) + self.config = config.get(INTENT_CAMERA_ANALYZE, {}) + self._config_loaded = True + + async def async_handle(self, intent_obj: intent.Intent) -> intent.IntentResponse: + await self._load_config() + slots = self.async_validate_slots(intent_obj.slots) + camera_name = self.get_slot_value(slots.get("camera_name")) + question = self.get_slot_value(slots.get("question")) + + LOGGER.info("Camera analyze intent info - 原始插槽: %s", slots) + + target_camera = next((e for e in intent_obj.hass.states.async_all(camera.DOMAIN) + if camera_name.lower() in (e.name.lower(), e.entity_id.lower())), None) + response = intent.IntentResponse(intent=intent_obj, language="zh-cn") + + if self.config and "speech" in self.config: + response.async_set_speech(self.config["speech"]["text"]) + + return (self._set_error_response(response, ERROR_NO_CAMERA, f"找不到名为 {camera_name} 的摄像头") + if not target_camera else await self._handle_camera_analysis(intent_obj.hass, response, target_camera, question)) + + async def _handle_camera_analysis(self, hass, response, target_camera, question) -> intent.IntentResponse: + try: + result = await hass.services.async_call(DOMAIN, "image_analyzer", + service_data={"model": "glm-4v-flash", "temperature": 0.8, "max_tokens": 1024, + "stream": False, "image_entity": target_camera.entity_id, "message": question}, + blocking=True, return_response=True) + return (self._set_speech_response(response, result.get("message", "")) + if result and isinstance(result, dict) and result.get("success", False) else + self._set_error_response(response, ERROR_SERVICE_CALL, result.get("message", "服务调用失败")) + if result and isinstance(result, dict) else + self._set_error_response(response, ERROR_NO_RESPONSE, "未能获取到有效的分析结果")) + except Exception as e: + return self._set_error_response(response, ERROR_SERVICE_CALL, f"服务调用出错:{str(e)}") + + def _set_error_response(self, response, code, message) -> intent.IntentResponse: + response.async_set_error(code=code, message=message) + return response + + def _set_speech_response(self, response, message) -> intent.IntentResponse: + response.async_set_speech(message) + return response + + def get_slot_value(self, slot_data): + return None if not slot_data else slot_data.get('value') if isinstance(slot_data, dict) else getattr(slot_data, 'value', None) if hasattr(slot_data, 'value') else str(slot_data) + + +class WebSearchIntent(intent.IntentHandler): + intent_type = INTENT_WEB_SEARCH + slot_schema = { + vol.Required("query"): str, + vol.Optional("time_query"): str, + } + + def __init__(self, hass: HomeAssistant): + super().__init__() + self.hass = hass + self.config = {} + self._config_loaded = False + + async def _load_config(self): + if not self._config_loaded: + yaml_path = os.path.join(os.path.dirname(__file__), "intents.yaml") + config = await async_load_yaml_config(self.hass, yaml_path) + self.config = config.get(INTENT_WEB_SEARCH, {}) + self._config_loaded = True + + async def async_handle(self, intent_obj: intent.Intent) -> intent.IntentResponse: + await self._load_config() + slots = self.async_validate_slots(intent_obj.slots) + query = self.get_slot_value(slots.get("query")) + time_query = self.get_slot_value(slots.get("time_query")) + + LOGGER.info("Web search info - 原始插槽:%s", slots) + + if not query: + response = intent.IntentResponse(intent=intent_obj, language="zh-cn") + return self._set_error_response(response, ERROR_NO_QUERY, "未提供搜索内容") + + if time_query: + now = datetime.now() # 简化时间处理 + if time_query in ["昨天", "昨日", "yesterday"]: + date = (now - timedelta(days=1)).strftime('%Y-%m-%d') + elif time_query in ["明天", "明日", "tomorrow"]: + date = (now + timedelta(days=1)).strftime('%Y-%m-%d') + else: + date = now.strftime('%Y-%m-%d') + query = f"{date} {query}" + + return await self._handle_web_search(intent_obj.hass, intent_obj, query) + + async def _handle_web_search( + self, hass: HomeAssistant, intent_obj: intent.Intent, query: str + ) -> intent.IntentResponse: + response = intent.IntentResponse(intent=intent_obj, language="zh-cn") + + try: + LOGGER.info("Web search service call - Query: %s", query) + result = await hass.services.async_call( + DOMAIN, + "web_search", + service_data={ + "query": query, + "stream": False + }, + blocking=True, + return_response=True + ) + + if result and isinstance(result, dict): + LOGGER.info("Web search result: %s", result) + if result.get("success", False): + return self._set_speech_response(response, result.get("message", "")) + return self._set_error_response( + response, ERROR_SERVICE_CALL, result.get("message", "搜索服务调用失败") + ) + + return self._set_error_response( + response, ERROR_NO_RESPONSE, "未能获取到有效的搜索结果" + ) + + except Exception as e: + return self._set_error_response( + response, ERROR_SERVICE_CALL, f"搜索服务调用出错:{str(e)}" + ) + + def _set_error_response(self, response, code, message) -> intent.IntentResponse: + response.async_set_error(code=code, message=message) + return response + + def _set_speech_response(self, response, message) -> intent.IntentResponse: + if len(message.encode('utf-8')) > 24 * 1024: + message = message[:24 * 1024].rsplit(' ', 1)[0] + "..." + response.async_set_speech(message) + return response + + def get_slot_value(self, slot_data): + return None if not slot_data else slot_data.get('value') if isinstance(slot_data, dict) else getattr(slot_data, 'value', None) if hasattr(slot_data, 'value') else str(slot_data) + +class HassTimerIntent(intent.IntentHandler): + intent_type = "HassTimerIntent" + slot_schema = { + vol.Required("action"): str, + vol.Optional("duration"): str, + vol.Optional("timer_name"): str + } + + def __init__(self, hass: HomeAssistant): + super().__init__() + self.hass = hass + self.config = {} + self._config_loaded = False + + async def _load_config(self): + if not self._config_loaded: + yaml_path = os.path.join(os.path.dirname(__file__), "intents.yaml") + self.config = (await async_load_yaml_config(self.hass, yaml_path)).get("HassTimerIntent", {}) + self._config_loaded = True + + async def async_handle(self, intent_obj: intent.Intent) -> intent.IntentResponse: + await self._load_config() + slots = self.async_validate_slots(intent_obj.slots) + get_slot_value = lambda slot_name: self.get_slot_value(slots.get(slot_name)) + action = get_slot_value("action") + duration = get_slot_value("duration") + timer_name = get_slot_value("timer_name") + + response = intent.IntentResponse(intent=intent_obj, language="zh-cn") + return response.async_set_error(code="no_action", message="未指定操作类型") if not action else await self._handle_timer(response, action, duration, timer_name) + + async def _handle_timer(self, response, action, duration, timer_name): + action_map = {"set": "start", "add": "start", "create": "start", "stop": "pause", "remove": "cancel", "delete": "cancel", "end": "finish", "提醒": "start"} + action = action_map.get(action, action) + LOGGER.info("Hass timer intent info - 原始插槽: %s", {"action": action, "duration": duration, "timer_name": timer_name}) + + timer_entities = self.hass.states.async_entity_ids("timer") + return response.async_set_error(code="no_timer", message="未找到可用的计时器,请先在Home Assistant中创建一个计时器") if not timer_entities else await self._process_timer(response, action, duration, timer_name, timer_entities[0]) + + async def _process_timer(self, response, action, duration, timer_name, timer_id): + time_words = {'早上': 7, '早晨': 7, '上午': 9, '中午': 12, '下午': 14, '晚上': 20, '傍晚': 18, '凌晨': 5, + '早饭': 7, '早餐': 7, '午饭': 12, '午餐': 12, '晚饭': 18, '晚餐': 18, '夜宵': 22} + + def parse_time(text): + return (None, None) if not text else self._parse_time_impl(text.lower(), time_words) + + minutes, is_absolute = parse_time(timer_name) or parse_time(duration) or (None, None) + + try: + data = {} + if minutes is not None and minutes > 0: + hours = minutes // 60 + mins = minutes % 60 + data["duration"] = f"{hours:02d}:{mins:02d}:00" + target_time = datetime.now() + timedelta(minutes=minutes) + + if is_absolute: + time_str = target_time.strftime('%H:%M') + response.async_set_speech(f"好的,已设置{timer_name if timer_name else '计时器'},将在{time_str}提醒您") + else: + time_str = (f"{hours}小时{mins}分钟后" if hours > 0 and mins > 0 else + f"{hours}小时后" if hours > 0 else + f"{mins}分钟后") + response.async_set_speech(f"好的,已设置{timer_name if timer_name else '计时器'},将在{time_str}提醒您") + else: + response.async_set_speech(f"好的,已{action}计时器") + await self.hass.services.async_call("timer", action, {"entity_id": timer_id, **data}, blocking=True) + return response + except Exception as e: + return response.async_set_error(code="service_call_error", message=f"操作失败:{str(e)}") + + def _parse_time_impl(self, text, time_words): + hour_match = re.search(r"(\d+)[点时:](\d+)?", text) + is_absolute = bool(hour_match or any(word in text for word in time_words.keys())) + + if is_absolute: + target_time = datetime.now() + timedelta(days=("明天" in text) + ("后天" in text) * 2) + hour = int(hour_match.group(1)) if hour_match else next((h for w, h in time_words.items() if w in text), None) + minute = int(hour_match.group(2)) if hour_match and hour_match.group(2) else 0 + hour = hour + 12 if hour and hour <= 12 and any(w in text for w in ['下午', '晚上', '傍晚', '晚饭', '晚餐', '夜宵']) else hour + target_time = target_time.replace(hour=hour, minute=minute) if hour is not None else target_time + minutes = int((target_time - datetime.now()).total_seconds() / 60) + return (minutes, True) if minutes > 0 else (None, True) + + matches = re.findall(r'(\d+)\s*([小时分钟天hmd]|hour|minute|min|hr|h|m)s?', text) + total_minutes = sum(int(value) * (60 if unit.startswith('h') or unit in ['小时'] else + 1 if unit.startswith('m') or unit in ['分钟'] or unit == 'm' else + 24 * 60) for value, unit in matches) + return total_minutes or None, False + + def get_slot_value(self, slot_data): + return None if not slot_data else slot_data.get('value') if isinstance(slot_data, dict) else getattr(slot_data, 'value', None) if hasattr(slot_data, 'value') else str(slot_data) + + +class HassNotifyIntent(intent.IntentHandler): + intent_type = "HassNotifyIntent" + slot_schema = {vol.Required("message"): str} + + def __init__(self, hass: HomeAssistant): + super().__init__() + self.hass = hass + + async def async_handle(self, intent_obj: intent.Intent) -> intent.IntentResponse: + response = intent.IntentResponse(intent=intent_obj, language="zh-cn") + + try: + message = self.get_slot_value(intent_obj.slots.get("message")) + if not message: + response.async_set_error("no_message", "请提供要发送的通知内容") + return response + + title_result = await self.hass.services.async_call( + "conversation", "process", + { + "agent_id": "conversation.zhi_pu_qing_yan", + "language": "zh-cn", + "text": f"请为这条消息生成一个简短的标题(不超过8个字):{message}" + }, + blocking=True, + return_response=True + ) + + title = "新通知" + if title_result and isinstance(title_result, dict): + ai_title = title_result.get("response", {}).get("speech", {}).get("plain", {}).get("speech", "") + if ai_title: + title = ai_title + + result = await self.hass.services.async_call( + "conversation", "process", + { + "agent_id": "conversation.zhi_pu_qing_yan", + "language": "zh-cn", + "text": f"请将以下内容改写成一条通知消息,只需返回改写后的文本内容,不要添加也不需要执行动作工具任何代码或格式,注意要使用表情emoji:{message}" + }, + blocking=True, + return_response=True + ) + + if not result or not isinstance(result, dict): + response.async_set_error("invalid_response", "AI 响应格式错误") + return response + + ai_response = result.get("response", {}) + ai_message = ai_response.get("speech", {}).get("plain", {}).get("speech", "") + if not ai_message: + ai_message = message + + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + from homeassistant.components import persistent_notification + persistent_notification.async_create( + self.hass, + f"{ai_message}\n\n创建时间:{current_time}", + title=f"{title}" + ) + + response.async_set_speech(f"已创建通知: {message}") + return response + + except Exception as err: + LOGGER.exception("发送通知失败") + response.async_set_error("notification_error", f"发送通知失败: {str(err)}") + return response + + def get_slot_value(self, slot_data): + return None if not slot_data else slot_data.get('value') if isinstance(slot_data, dict) else getattr(slot_data, 'value', None) if hasattr(slot_data, 'value') else str(slot_data) + + +def extract_intent_info(user_input: str, hass: HomeAssistant) -> Optional[Dict[str, Any]]: + entity_match = re.search(r'([\w_]+\.[\w_]+)', user_input) + entity_id = entity_match.group(1) if entity_match else None + domain = entity_id.split('.')[0] if entity_id else None + + intent_mappings = { + r'打开|开启|解锁|turn on|open|unlock': 'turn_on', + r'关闭|关掉|锁定|turn off|close|lock': 'turn_off', + r'切换|toggle': 'toggle', + r'停止|暂停|stop|pause': 'stop', + r'继续|resume': 'start', + r'设置|调整|set|adjust': 'set', + } + + action = next((action for pattern, action in intent_mappings.items() + if re.search(pattern, user_input)), 'turn_on') + + action = 'lock' if domain == 'lock' and action == 'turn_on' else \ + 'unlock' if domain == 'lock' and action == 'turn_off' else action + + intent_data = {'domain': domain, 'action': action, 'data': {'entity_id': entity_id}} if entity_id else None + + + if action == 'set' and intent_data: + number_match = re.search(r'(\d+)', user_input) + value = int(number_match.group(1)) if number_match else None + if value is not None: + intent_data['data'].update({ + 'position': value if domain == 'cover' else + min(255, value * 255 // 100) if domain == 'light' else + value if domain == 'climate' else None + }) + + return intent_data + +class IntentHandler: + def __init__(self, hass: HomeAssistant): + self.hass = hass + self.area_reg = area_registry.async_get(hass) + self.device_reg = device_registry.async_get(hass) + self.entity_reg = entity_registry.async_get(hass) + + climate_modes = {"cool": "制冷模式", "heat": "制热模式", "auto": "自动模式", "dry": "除湿模式", "fan_only": "送风模式", "off": "关闭"} + fan_modes = {"on_high": "高速风", "on_low": "低速风", "auto_high": "自动高速", "auto_low": "自动低速", "off": "关闭风速"} + media_actions = {"turn_on": "打开", "turn_off": "关闭", "volume_up": "调高音量", "volume_down": "调低音量", "volume_mute": "静音", "media_play": "播放", "media_pause": "暂停", "media_stop": "停止", "media_next_track": "下一曲", "media_previous_track": "上一曲", "select_source": "切换输入源", "shuffle_set": "随机播放", "repeat_set": "循环播放", "play_media": "播放媒体"} + cover_actions = {"open_cover": "打开", "close_cover": "关闭", "stop_cover": "停止", "toggle": "切换", "set_cover_position": "设置位置", "set_cover_tilt_position": "设置角度"} + vacuum_actions = {"start": "启动", "pause": "暂停", "stop": "停止", "return_to_base": "返回充电", "clean_spot": "定点清扫", "locate": "定位", "set_fan_speed": "设置吸力"} + fan_directions = {"forward": "正向", "reverse": "反向"} + automation_actions = {"turn_on": "启用", "turn_off": "禁用", "trigger": "触发", "toggle": "切换"} + boolean_actions = {"turn_on": "打开", "turn_off": "关闭", "toggle": "切换"} + timer_actions = {"start": "启动", "pause": "暂停", "cancel": "取消", "finish": "结束", "reload": "重新加载"} + + async def call_service(self, domain: str, service: str, data: Dict[str, Any]) -> Dict[str, Any]: + try: + entity_id = data.get("entity_id") + entity = self.hass.states.get(entity_id) + friendly_name = entity.attributes.get("friendly_name") if entity else "设备" + await self.hass.services.async_call(domain, service, {**data, "entity_id": entity_id} if entity_id else data, blocking=True) + target_temp = data.get('temperature') + current_temp = entity.attributes.get('current_temperature') if entity and domain == "climate" and service == "set_temperature" else None + _ = await self.hass.services.async_call(domain, "set_hvac_mode", {"entity_id": entity_id, "hvac_mode": "cool" if current_temp > target_temp else "heat"}, blocking=True) if current_temp is not None and target_temp is not None else None + return ({"success": True, "message": f"已设置 {friendly_name} {'制冷' if current_temp > target_temp else '制热'}模式,温度{target_temp}度"} if domain == "climate" and service == "set_temperature" and current_temp is not None and target_temp is not None else + {"success": True, "message": f"已设置 {friendly_name} 温度{target_temp}度"} if domain == "climate" and service == "set_temperature" and target_temp is not None else + {"success": True, "message": f"已执行 {friendly_name} {self.climate_modes.get(data.get('hvac_mode'), data.get('hvac_mode'))}"} if domain == "climate" and service == "set_hvac_mode" else + {"success": True, "message": f"已执行 {friendly_name} {self.fan_modes.get(data.get('fan_mode'), data.get('fan_mode'))}"} if domain == "climate" and service == "set_fan_mode" else + {"success": True, "message": f"已执行 {friendly_name} 湿度{data.get('humidity')}%"} if domain == "climate" and service == "set_humidity" else + {"success": True, "message": f"已执行 {friendly_name} 亮度{int(data['brightness'] * 100 / 255)}%"} if domain == "light" and service == "turn_on" and "brightness" in data else + {"success": True, "message": f"已设置 {friendly_name} {'颜色' if 'rgb_color' in data else '色温'}"} if domain == "light" and service == "turn_on" and ("rgb_color" in data or "color_temp" in data) else + {"success": True, "message": f"已{'打开' if service == 'turn_on' else '关闭'} {friendly_name}"} if domain == "light" and service in ["turn_on", "turn_off"] else + {"success": True, "message": f"已{self.media_actions.get(service, service)} {friendly_name}"} if domain == "media_player" else + {"success": True, "message": f"已设置 {friendly_name} {'位置到' + str(data['position']) + '%' if 'position' in data else '角度到' + str(data['tilt_position']) + '%'}"} if domain == "cover" and ("position" in data or "tilt_position" in data) else + {"success": True, "message": f"已{self.cover_actions.get(service, service)} {friendly_name}"} if domain == "cover" else + {"success": True, "message": f"已{'打开' if service == 'turn_on' else '关闭'} {friendly_name}"} if domain == "switch" and service in ["turn_on", "turn_off"] else + {"success": True, "message": f"已设置 {friendly_name} {'风速' + str(data['percentage']) + '%' if 'percentage' in data else data['preset_mode'] + '模式'}"} if domain == "fan" and service == "turn_on" and ("percentage" in data or "preset_mode" in data) else + {"success": True, "message": f"已{'打开' if service == 'turn_on' else '关闭'} {friendly_name}"} if domain == "fan" and service in ["turn_on", "turn_off"] else + {"success": True, "message": f"已{'开启' if data.get('oscillating') else '关闭'} {friendly_name} 摆风"} if domain == "fan" and service == "oscillate" else + {"success": True, "message": f"已设置 {friendly_name} {self.fan_directions.get(data.get('direction'), data.get('direction'))}旋转"} if domain == "fan" and service == "set_direction" else + {"success": True, "message": f"已启动场景 {friendly_name}"} if domain == "scene" and service == "turn_on" else + {"success": True, "message": f"已执行脚本 {friendly_name}"} if domain == "script" and service in ["turn_on", "start"] else + {"success": True, "message": f"已{self.automation_actions.get(service, service)}自动化 {friendly_name}"} if domain == "automation" else + {"success": True, "message": f"已{self.boolean_actions.get(service, service)} {friendly_name}"} if domain == "input_boolean" else + {"success": True, "message": f"已{self.timer_actions.get(service, service)}计时器 {friendly_name}"} if domain == "timer" else + {"success": True, "message": f"已设置 {friendly_name} 吸力为{data['fan_speed']}"} if domain == "vacuum" and service == "set_fan_speed" and "fan_speed" in data else + {"success": True, "message": f"已{self.vacuum_actions.get(service, service)} {friendly_name}"} if domain == "vacuum" else + {"success": True, "message": f"已按下 {friendly_name}"} if domain == "button" and service == "press" else + {"success": True, "message": f"已执行 {friendly_name} {service}"}) + except Exception as e: + return {"success": False, "message": str(e)} + + async def handle_intent(self, intent_info: Dict[str, Any]) -> Dict[str, Any]: + domain = intent_info.get('domain') + action = intent_info.get('action', '') + data = intent_info.get('data', {}) + name = data['name'].get('value') if isinstance(data.get('name'), dict) else data.get('name') + area = data['area'].get('value') if isinstance(data.get('area'), dict) else data.get('area') + entity_id = data.get('entity_id') + + return await ( + self.handle_nevermind_intent() if domain == CONVERSATION_DOMAIN and action == INTENT_NEVERMIND else + self.call_service(domain, action, data) if entity_id and await self._validate_service_for_entity(domain, action, entity_id) else + self.handle_cover_intent(action, name, area, data) if domain == "cover" else + self.handle_lock_intent(action, name, area, data) if domain == "lock" else + self.handle_timer_intent(action, name, area, data) if domain == TIMER_DOMAIN else + self.call_service(domain, action, data) + ) + + async def handle_nevermind_intent(self) -> Dict[str, Any]: + try: + await self.hass.services.async_call(CONVERSATION_DOMAIN, SERVICE_PROCESS, {"text": "再见"}, blocking=True) + return {"success": True, "message": "再见!", "close_conversation": True} + except Exception as e: + return {"success": False, "message": str(e)} + + async def _validate_service_for_entity(self, domain: str, service: str, entity_id: str) -> bool: + state = self.hass.states.get(entity_id) + return bool( + state and entity_id in self.hass.states.async_entity_ids() and + (supported_features := state.attributes.get("supported_features", 0)) and + not ( + (domain == "fan" and service == "set_percentage" and not (supported_features & 1)) or + (domain == "cover" and ((service == "close" and not (supported_features & 2)) or + (service == "set_position" and not (supported_features & 4)))) or + (domain == "lock" and ((service == "unlock" and not (supported_features & 1)) or + (service == "lock" and not (supported_features & 2)))) + ) + ) + async def handle_cover_intent(self, action: str, name: str, area: str, data: Dict[str, Any]) -> Dict[str, Any]: + entity_id = data.get('entity_id') + return ( + {"success": False, "message": "未指定窗帘实体"} if not entity_id else + await self.call_service("cover", "set_cover_position", + {"entity_id": entity_id, "position": data.get('position')}) + if action == "set" and data.get('position') is not None else + await self.call_service("cover", action, {"entity_id": entity_id}) + ) + + async def handle_lock_intent(self, action: str, name: str, area: str, data: Dict[str, Any]) -> Dict[str, Any]: + entity_id = data.get('entity_id') + return ( + {"success": False, "message": "未指定门锁实体"} if not entity_id else + await self.call_service("lock", action, {"entity_id": entity_id}) + ) + + async def handle_timer_intent(self, action: str, name: str, area: str, data: Dict[str, Any]) -> Dict[str, Any]: + duration = data.get("duration", "") + timer_id = data.get('entity_id') + + minutes = ( + int(''.join(filter(str.isdigit, duration))) if "minutes" in duration else + int(minutes_match.group(1)) if (minutes_match := re.search(r'(\d+)\s*分钟', duration)) else + None + ) + + return { + "action": f"timer.{action}", + "data": {"duration": f"00:{minutes:02d}:00"} if action == "start" and minutes is not None else {}, + "target": {"entity_id": timer_id} + } + +def get_intent_handler(hass: HomeAssistant) -> IntentHandler: + return IntentHandler(hass) \ No newline at end of file diff --git a/custom_components/zhipuai/intents.yaml b/custom_components/zhipuai/intents.yaml new file mode 100644 index 0000000..53f7bfe --- /dev/null +++ b/custom_components/zhipuai/intents.yaml @@ -0,0 +1,352 @@ +language: "zh-cn" +intents: + HassTurnOn: + data: + - sentences: + - "打开(?:[|])?([^|]+)" + - "开启(?:[|])?([^|]+)" + - "turn on(?:[|])?([^|]+)" + slots: + name: + type: text + device_class: + type: list + values: + - awning + - blind + - curtain + - damper + - door + - garage + - gas + - gate + - outlet + - receiver + - shade + - shutter + - speaker + - switch + - tv + - water + - window + + HassTurnOff: + data: + - sentences: + - "关闭{name}" + - "关掉{name}" + - "turn off {name}" + slots: + name: + type: text + device_class: + type: list + values: + - awning + - blind + - curtain + - damper + - door + - garage + - gas + - gate + - outlet + - receiver + - shade + - shutter + - speaker + - switch + - tv + - water + - window + + MassPlayMediaAssist: + data: + - sentences: + - "在(歌手|艺术家){artist}的(歌|音乐)" + - "在{album}" + - "在{track}" + - "在{playlist}" + - "在{radio}" + expansion_rules: + play: "(play|播放|收听)" + track: "(track|歌曲|音乐)" + album: "(album|专辑|唱片|合集|单曲)" + playlist: "(playlist|播放列表)" + radio_station: "(radio_station|广播电台|电台|频道)" + + - sentences: + - "用{name}(歌手|艺术家){artist}的(歌|音乐)" + - "用{name}{album}" + - "用{name}{track}" + - "用{name}{playlist}" + - "用{name}{radio}" + expansion_rules: + play: "(play|播放|收听)" + track: "(track|歌曲|音乐)" + album: "(album|专辑|唱片|合集|单曲)" + playlist: "(playlist|播放列表)" + radio_station: "(radio_station|广播电台|电台|频道)" + requires_context: + domain: "media_player" + + - sentences: + - "用在[里]的{name}(歌手|艺术家){artist}的(歌|音乐)" + - "用在[里]的{name}{album}" + - "用在[里]的{name}{track}" + - "用在[里]的{name}{playlist}" + - "用在[里]的{name}{radio}" + expansion_rules: + play: "(play|播放|收听)" + track: "(track|歌曲|音乐)" + album: "(album|专辑|唱片|合集|单曲)" + playlist: "(playlist|播放列表)" + radio_station: "(radio_station|广播电台|电台|频道)" + + - sentences: + - "(歌手|艺术家){artist}的(歌|音乐)" + - "{album}" + - "{track}" + - "{playlist}" + - "{radio}" + expansion_rules: + play: "(play|播放|收听)" + track: "(track|歌曲|音乐)" + album: "(album|专辑|唱片|合集|单曲)" + playlist: "(playlist|播放列表)" + radio_station: "(radio_station|广播电台|电台|频道)" + requires_context: + area: + slot: true + + ZHIPUAI_CoverGetStateIntent: + data: + - sentences: + - "[]{name}(是|是不是){cover_states:state}[吗|不]" + response: one_yesno + requires_context: + domain: cover + slots: + domain: cover + + - sentences: + - "[{area}][有|有没有]{cover_classes:device_class}[是|是不是]{cover_states:state}[吗|不]" + response: any + slots: + domain: cover + + - sentences: + - "[][]{cover_classes:device_class}[是|是不是]都[是]{cover_states:state}[吗|不]" + - "{cover_classes:device_class}[是|是不是]都[是]{cover_states:state}[吗|不]" + response: all + slots: + domain: cover + + - sentences: + - "[]{cover_classes:device_class}[是]{cover_states:state}" + - "[]{cover_classes:device_class}[是]{cover_states:state}" + response: which + slots: + domain: cover + + - sentences: + - "[{area}]{cover_classes:device_class}[是]{cover_states:state}" + - "[]{cover_classes:device_class}[是]{cover_states:state}" + response: how_many + slots: + domain: cover + + ZHIPUAI_CoverSetPositionIntent: + data: + - sentences: + - "(||) [position] " + requires_context: + domain: cover + slots: + domain: cover + + - sentences: + - "(||) {cover_classes:device_class}[ position] " + slots: + domain: cover + + - sentences: + - "[] (||) [position] " + requires_context: + domain: cover + slots: + domain: cover + + - sentences: + - "[]{cover_classes:device_class}(||)[ position] " + slots: + domain: cover + + HassGetState: + data: + - sentences: + - "[查询|查看][]{name}[的]状态" + - "[]{name}[现在|当前]是什么状态" + - "[]{name}[怎么样|如何]" + slots: + domain: all + + - sentences: + - "[]{name}[的锁](是不是|有没有){lock_states:state}" + - "[]{name}[的锁][是|有]{lock_states:state}[吗|不]" + response: one_yesno + requires_context: + domain: lock + slots: + domain: lock + + ZhipuAIWebSearch: + data: + - sentences: + - "联网{time_query}{query}" + - "联网{time_query}的{query}" + - "帮我联网{time_query}{query}" + - "帮我上网查{time_query}{query}" + - "联网搜索{time_query}{query}" + - "联网查找{time_query}{query}" + - "联网查询{time_query}{query}" + - "在网上搜索{time_query}{query}" + - "在网上查找{time_query}{query}" + - "在网上查询{time_query}{query}" + - "上网搜索{time_query}{query}" + - "上网查找{time_query}{query}" + - "上网查询{time_query}{query}" + - "网上搜索{time_query}{query}" + - "网上查找{time_query}{query}" + - "网上查询{time_query}{query}" + - "互联网搜索{time_query}{query}" + - "互联网查找{time_query}{query}" + - "互联网查询{time_query}{query}" + - "百度{time_query}{query}" + - "谷歌{time_query}{query}" + - "必应{time_query}{query}" + - "搜一下{time_query}{query}" + - "查一下{time_query}{query}" + - "帮我找找{time_query}{query}" + - "帮忙查找{time_query}{query}" + - "帮忙查询{time_query}{query}" + - "search for {time_query}{query}" + - "search {time_query}{query}" + - "find {time_query}{query}" + - "look up {time_query}{query}" + - "google {time_query}{query}" + - "bing {time_query}{query}" + slots: + query: + type: text + example: "中国队奥运会奖牌数" + time_query: + type: list + values: + - "今天" + - "今日" + - "昨天" + - "昨日" + - "明天" + - "明日" + - "当前" + - "现在" + + CameraAnalyzeIntent: + data: + - sentences: + - "查看{camera_name}的{question}" + - "看看{camera_name}的{question}" + - "告诉我{camera_name}的{question}" + - "{camera_name}现在{question}" + - "分析{camera_name}的{question}" + - "查看{camera_name}的{question}" + - "检查{camera_name}的{question}" + - "观察{camera_name}的{question}" + - "识别{camera_name}的{question}" + - "告诉我{camera_name}的{question}" + - "帮我看看{camera_name}的{question}" + - "帮我分析{camera_name}的{question}" + - "帮我检查{camera_name}的{question}" + - "帮我识别{camera_name}的{question}" + - "{camera_name}那里的{question}" + - "{camera_name}现在的{question}" + - "{camera_name}目前的{question}" + - "看一下{camera_name}的{question}" + - "分析一下{camera_name}的{question}" + - "检查一下{camera_name}的{question}" + - "观察一下{camera_name}的{question}" + - "识别一下{camera_name}的{question}" + - "分析下{camera_name}{question}" + - "查看下{camera_name}{question}" + - "看看{camera_name}{question}" + - "analyze {camera_name} {question}" + - "check {camera_name} {question}" + - "look at {camera_name} {question}" + - "identify {camera_name} {question}" + slots: + camera_name: + type: text + question: + type: text + speech: + text: "正在分析摄像头画面,请稍等..." + + HassTimerIntent: + action: + - timer + speech: + text: 设置计时器 + slots: + action: + type: text + duration: + type: text + timer_name: + type: text + templates: + - "[设置|创建|添加|开始|启动]{duration}[提醒|记录|通知]{timer_name}" + - "[设置|创建|添加|开始|启动][提醒|记录|通知]{duration}{timer_name}" + - "{duration}[提醒|记录|通知]{timer_name}" + - "[设置|创建|添加|开始|启动][提醒|记录|通知]{timer_name}{duration}" + - "[设置|创建|添加|开始|启动]一个{duration}的[计时器|定时器]{timer_name}" + - "[设置|创建|添加|开始|启动]{duration}的[计时器|定时器]{timer_name}" + - "在{duration}[后|以后|之后][设置|创建|添加|开始|启动][计时器|定时器]{timer_name}" + - "{duration}[后|以后|之后][设置|创建|添加|开始|启动][计时器|定时器]{timer_name}" + - "[取消|停止|关闭|结束|删除][计时器|定时器|提醒|记录|通知]{timer_name}" + + HassNotifyIntent: + action: + - notify + speech: + text: 发送通知 + slots: + message: + type: text + templates: + - "[记住|记录|记一下|提醒我|提醒|提醒一下]{message}" + - "帮我[记住|记录|记一下]{message}" + - "帮我[提醒|提醒一下]{message}" + - "[记住|记录|记一下|提醒我|提醒|提醒一下][这个|这件事]{message}" + - "帮我[记住|记录|记一下][这个|这件事]{message}" + - "帮我[提醒|提醒一下][这个|这件事]{message}" + +responses: + intents: + MassPlayMediaAssist: + default: "ok" + +lists: + artist: + wildcard: true + album: + wildcard: true + track: + wildcard: true + playlist: + wildcard: true + radio: + wildcard: true + radio_mode: + values: + - "radio mode" \ No newline at end of file diff --git a/custom_components/zhipuai/manifest.json b/custom_components/zhipuai/manifest.json new file mode 100644 index 0000000..42bb002 --- /dev/null +++ b/custom_components/zhipuai/manifest.json @@ -0,0 +1,13 @@ +{ + "domain": "zhipuai", + "name": "\u667a\u8c31\u6e05\u8a00", + "codeowners": ["@knoop7"], + "config_flow": true, + "documentation": "https://github.com/knoop7/zhipuai", + "integration_type": "hub", + "iot_class": "cloud_polling", + "issue_tracker": "https://github.com/knoop7/zhipuai/issues", + "requirements": [], + "version": "2024.01.13", + "translations": ["en", "zh-Hans"] +} diff --git a/custom_components/zhipuai/services.py b/custom_components/zhipuai/services.py new file mode 100644 index 0000000..792fecd --- /dev/null +++ b/custom_components/zhipuai/services.py @@ -0,0 +1,440 @@ +from __future__ import annotations + +import base64 +import io +import os +import os.path +import numpy as np +from PIL import Image +import voluptuous as vol +from homeassistant.core import HomeAssistant, ServiceCall, callback +from homeassistant.helpers import config_validation as cv +from homeassistant.components import camera +from homeassistant.components.camera import Image as CameraImage +from homeassistant.exceptions import ServiceValidationError +import requests +import json +import time +import asyncio + +from .const import ( + DOMAIN, + LOGGER, + ZHIPUAI_URL, + CONF_TEMPERATURE, + RECOMMENDED_TEMPERATURE, + CONF_MAX_TOKENS, + RECOMMENDED_MAX_TOKENS, +) + +class ImageProcessor: + def __init__(self, hass: HomeAssistant): + self.hass = hass + + async def resize_image(self, image_data: bytes, target_width: int = 1280, is_gif: bool = False) -> str: + try: + img_byte_arr = io.BytesIO(image_data) + img = await self.hass.async_add_executor_job(Image.open, img_byte_arr) + + if is_gif and img.format == 'GIF' and getattr(img, "is_animated", False): + frames = [] + try: + for frame in range(img.n_frames): + img.seek(frame) + new_frame = await self.hass.async_add_executor_job(lambda x: x.convert('RGB'), img.copy()) + + width, height = new_frame.size + aspect_ratio = width / height + target_height = int(target_width / aspect_ratio) + if width > target_width or height > target_height: + new_frame = await self.hass.async_add_executor_job(lambda x: x.resize((target_width, target_height)), new_frame) + frames.append(new_frame) + + output = io.BytesIO() + frames[0].save(output, save_all=True, append_images=frames[1:], format='GIF', duration=img.info.get('duration', 100), loop=0) + base64_image = base64.b64encode(output.getvalue()).decode('utf-8') + return base64_image + except Exception as e: + LOGGER.error(f"GIF处理错误: {str(e)}") + img.seek(0) + + if img.mode == 'RGBA' or img.format == 'GIF': + img = await self.hass.async_add_executor_job(lambda x: x.convert('RGB'), img) + + width, height = img.size + aspect_ratio = width / height + target_height = int(target_width / aspect_ratio) + + if width > target_width or height > target_height: + img = await self.hass.async_add_executor_job(lambda x: x.resize((target_width, target_height)), img) + + img_byte_arr = io.BytesIO() + await self.hass.async_add_executor_job(lambda i, b: i.save(b, format='JPEG'), img, img_byte_arr) + base64_image = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8') + + return base64_image + except Exception as e: + LOGGER.error(f"图像处理错误: {str(e)}") + raise ServiceValidationError(f"图像处理失败: {str(e)}") + + def _similarity_score(self, previous_frame, current_frame_gray): + K1 = 0.005 + K2 = 0.015 + L = 255 + + C1 = (K1 * L) ** 2 + C2 = (K2 * L) ** 2 + + previous_frame_np = np.array(previous_frame, dtype=np.float64) + current_frame_np = np.array(current_frame_gray, dtype=np.float64) + + mu1 = np.mean(previous_frame_np) + mu2 = np.mean(current_frame_np) + + sigma1_sq = np.var(previous_frame_np) + sigma2_sq = np.var(current_frame_np) + sigma12 = np.cov(previous_frame_np.flatten(), current_frame_np.flatten())[0, 1] + + ssim = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ((mu1**2 + mu2**2 + C1) * (sigma1_sq + sigma2_sq + C2)) + + return ssim + +IMAGE_ANALYZER_SCHEMA = vol.Schema({ + vol.Required("model", default="glm-4v-flash"): vol.In(["glm-4v-plus", "glm-4v", "glm-4v-flash"]), + vol.Required("message"): cv.string, + vol.Optional("image_file"): cv.string, + vol.Optional("image_entity"): cv.entity_id, + vol.Optional(CONF_TEMPERATURE, default=RECOMMENDED_TEMPERATURE): vol.All(vol.Coerce(float), vol.Range(min=0.1, max=1.0)), + vol.Optional(CONF_MAX_TOKENS, default=RECOMMENDED_MAX_TOKENS): vol.All(vol.Coerce(int), vol.Range(min=1, max=1024)), + vol.Optional("stream", default=False): cv.boolean, + vol.Optional("target_width", default=1280): vol.All(vol.Coerce(int), vol.Range(min=512, max=1920)), +}) + +VIDEO_ANALYZER_SCHEMA = vol.Schema( + { + vol.Required("model", default="glm-4v-plus"): cv.string, + vol.Required("message"): cv.string, + vol.Required("video_file"): cv.string, + vol.Optional(CONF_TEMPERATURE, default=RECOMMENDED_TEMPERATURE): vol.All( + vol.Coerce(float), vol.Range(min=0.1, max=1.0) + ), + vol.Optional(CONF_MAX_TOKENS, default=RECOMMENDED_MAX_TOKENS): vol.All( + vol.Coerce(int), vol.Range(min=1, max=1024) + ), + vol.Optional("stream", default=False): cv.boolean, + vol.Optional("target_width", default=1280): vol.All( + vol.Coerce(int), vol.Range(min=512, max=1920) + ), + } +) + +async def async_setup_services(hass: HomeAssistant) -> None: + image_processor = ImageProcessor(hass) + + async def handle_image_analyzer(call: ServiceCall) -> None: + try: + config_entries = hass.config_entries.async_entries(DOMAIN) + if not config_entries: + raise ValueError("未找到 ZhipuAI") + api_key = config_entries[0].data.get("api_key") + if not api_key: + raise ValueError("在配置中找不到 API 密钥") + + image_data = None + if image_file := call.data.get("image_file"): + try: + if not os.path.isabs(image_file): + image_file = os.path.join(hass.config.config_dir, image_file) + + if os.path.isdir(image_file): + raise ServiceValidationError(f"提供的路径是一个目录: {image_file}") + + if not os.path.exists(image_file): + raise ServiceValidationError(f"图片文件不存在: {image_file}") + + with open(image_file, "rb") as f: + image_data = f.read() + except IOError as e: + LOGGER.error(f"读取图片文件失败 {image_file}: {str(e)}") + raise ServiceValidationError(f"读取图片文件失败: {str(e)}") + + elif image_entity := call.data.get("image_entity"): + try: + if not image_entity.startswith("camera."): + raise ServiceValidationError(f"无效的摄像头实体ID: {image_entity}") + + if not hass.states.get(image_entity): + raise ServiceValidationError(f"摄像头实体不存在: {image_entity}") + + try: + image: CameraImage = await camera.async_get_image(hass, image_entity, timeout=10) + + if not image or not image.content: + raise ServiceValidationError(f"无法从摄像头获取图片: {image_entity}") + + image_data = image.content + base64_image = await image_processor.resize_image(image_data, target_width=call.data.get("target_width", 1280), is_gif=True) + + except (camera.CameraEntityImageError, TimeoutError) as e: + raise ServiceValidationError(f"获取摄像头图片失败: {str(e)}") + + except Exception as e: + LOGGER.error(f"从实体获取图片失败 {image_entity}: {str(e)}") + raise ServiceValidationError(f"从实体获取图片失败: {str(e)}") + + if not image_data: + raise ServiceValidationError("未提供图片数据") + + try: + base64_image = await image_processor.resize_image(image_data, target_width=call.data.get("target_width", 1280)) + except Exception as e: + LOGGER.error(f"图像处理失败: {str(e)}") + raise ServiceValidationError(f"图像处理失败: {str(e)}") + + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {api_key}' + } + payload = { + "model": call.data["model"], + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}" + } + }, + { + "type": "text", + "text": call.data["message"] + } + ] + } + ], + "stream": call.data.get("stream", False), + "temperature": call.data.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), + "max_tokens": call.data.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), + } + + try: + response = await hass.async_add_executor_job(lambda: requests.post(ZHIPUAI_URL, headers=headers, json=payload, stream=call.data.get("stream", False), timeout=30)) + response.raise_for_status() + except requests.exceptions.RequestException as e: + LOGGER.error(f"API请求失败: {str(e)}") + raise ServiceValidationError(f"API请求失败: {str(e)}") + + if call.data.get("stream", False): + event_id = f"zhipuai_response_{int(time.time())}" + + try: + hass.bus.async_fire(f"{DOMAIN}_stream_start", {"event_id": event_id, "type": "image_analysis"}) + + accumulated_text = "" + for line in response.iter_lines(): + if line: + try: + data = line.decode('utf-8').replace('data: ', '') + if data == '[DONE]': + break + + json_data = json.loads(data) + if 'choices' in json_data and len(json_data['choices']) > 0: + content = json_data['choices'][0].get('delta', {}).get('content', '') + if content: + accumulated_text += content + hass.bus.async_fire(f"{DOMAIN}_stream_token", {"event_id": event_id, "content": content, "full_content": accumulated_text}) + + except json.JSONDecodeError as e: + LOGGER.error(f"解析流式响应失败: {str(e)}") + continue + + hass.bus.async_fire(f"{DOMAIN}_stream_end", {"event_id": event_id, "full_content": accumulated_text}) + + hass.bus.async_fire(f"{DOMAIN}_response", {"type": "image_analysis", "content": accumulated_text, "success": True}) + return {"success": True, "event_id": event_id, "message": accumulated_text} + + except Exception as e: + error_msg = f"处理流式响应时出错: {str(e)}" + LOGGER.error(error_msg) + hass.bus.async_fire(f"{DOMAIN}_stream_error", {"event_id": event_id, "error": error_msg}) + return {"success": False, "message": error_msg} + else: + result = response.json() + if result.get("choices") and len(result["choices"]) > 0: + content = result["choices"][0].get("message", {}).get("content", "") + hass.bus.async_fire(f"{DOMAIN}_response", {"type": "image_analysis", "content": content, "success": True}) + return {"success": True, "message": content} + else: + error_msg = "No response from API" + return {"success": False, "message": error_msg} + except Exception as e: + error_msg = f"Image analysis failed: {str(e)}" + LOGGER.error(f"图像分析错误: {str(e)}") + return {"success": False, "message": error_msg} + + async def handle_video_analyzer(call: ServiceCall) -> None: + try: + config_entries = hass.config_entries.async_entries(DOMAIN) + if not config_entries: + raise ValueError("No ZhipuAI configuration found") + api_key = config_entries[0].data.get("api_key") + if not api_key: + raise ValueError("API key not found in configuration") + + video_file = call.data["video_file"] + if not os.path.isabs(video_file): + video_file = os.path.join(hass.config.config_dir, video_file) + + if not os.path.isfile(video_file): + LOGGER.error(f"视频文件未找到或是目录: {video_file}") + return {"success": False, "message": f"视频文件未找到或是目录: {video_file}"} + + try: + def read_video_file(): + with open(video_file, "rb") as f: + return f.read() + video_data = await hass.async_add_executor_job(read_video_file) + except IOError as e: + LOGGER.error(f"读取视频文件失败 {video_file}: {str(e)}") + return {"success": False, "message": f"读取视频文件失败: {str(e)}"} + + if call.data.get("model") != "glm-4v-plus": + LOGGER.warning("视频分析仅支持glm-4v-plus模型,强制使用glm-4v-plus") + + video_size = len(video_data) / (1024 * 1024) + if video_size > 20: + LOGGER.error(f"视频文件大小 ({video_size:.1f}MB) 超过20MB限制") + return {"success": False, "message": f"视频文件大小 ({video_size:.1f}MB) 超过20MB限制"} + + if not video_file.lower().endswith('.mp4'): + LOGGER.error("视频文件必须是MP4格式") + return {"success": False, "message": "视频文件必须是MP4格式"} + + base64_video = base64.b64encode(video_data).decode('utf-8') + + payload = { + "model": "glm-4v-plus", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "video_url", + "video_url": { + "url": base64_video + } + }, + { + "type": "text", + "text": call.data["message"] + } + ] + } + ], + "stream": call.data.get("stream", False) + } + + if CONF_TEMPERATURE in call.data: + payload["temperature"] = call.data[CONF_TEMPERATURE] + if CONF_MAX_TOKENS in call.data: + payload["max_tokens"] = call.data[CONF_MAX_TOKENS] + + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {api_key}' + } + + try: + for attempt in range(3): + try: + response = await hass.async_add_executor_job( + lambda: requests.post( + ZHIPUAI_URL, + headers=headers, + json=payload, + stream=call.data.get("stream", False), + timeout=30 + ) + ) + response.raise_for_status() + break + except requests.exceptions.RequestException as e: + if attempt == 2: + raise + if response.status_code == 429: + await asyncio.sleep(2) + else: + raise + except requests.exceptions.RequestException as e: + LOGGER.error(f"API请求失败: {str(e)}") + return {"success": False, "message": f"API请求失败: {str(e)}"} + + if call.data.get("stream", False): + event_id = f"zhipuai_response_{int(time.time())}" + + try: + hass.bus.async_fire(f"{DOMAIN}_stream_start", {"event_id": event_id, "type": "video_analysis"}) + + accumulated_text = "" + for line in response.iter_lines(): + if line: + try: + line_text = line.decode('utf-8').strip() + if line_text.startswith('data: '): + line_text = line_text[6:] + if line_text == '[DONE]': + break + + json_data = json.loads(line_text) + if 'choices' in json_data and json_data['choices']: + choice = json_data['choices'][0] + if 'delta' in choice and 'content' in choice['delta']: + content = choice['delta']['content'] + accumulated_text += content + hass.bus.async_fire( + f"{DOMAIN}_stream_token", + { + "event_id": event_id, + "token": content, + "complete_text": accumulated_text + } + ) + except json.JSONDecodeError as e: + LOGGER.error(f"解析流式响应失败: {str(e)}") + continue + + hass.bus.async_fire(f"{DOMAIN}_stream_end", { + "event_id": event_id, + "complete_text": accumulated_text + }) + return {"success": True, "message": accumulated_text} + except Exception as e: + LOGGER.error(f"处理流式响应时出错: {str(e)}") + return {"success": False, "message": f"处理流式响应时出错: {str(e)}"} + else: + result = response.json() + if result.get("choices") and len(result["choices"]) > 0: + content = result["choices"][0].get("message", {}).get("content", "") + hass.bus.async_fire(f"{DOMAIN}_response", {"type": "video_analysis", "content": content, "success": True}) + return {"success": True, "message": content} + else: + error_msg = "No response from API" + return {"success": False, "message": error_msg} + except Exception as e: + error_msg = f"Video analysis failed: {str(e)}" + LOGGER.error(f"视频分析错误: {str(e)}") + return {"success": False, "message": error_msg} + + hass.services.async_register(DOMAIN, "image_analyzer", handle_image_analyzer, schema=IMAGE_ANALYZER_SCHEMA, supports_response=True) + + hass.services.async_register(DOMAIN, "video_analyzer", handle_video_analyzer, schema=VIDEO_ANALYZER_SCHEMA, supports_response=True) + + @callback + def async_unload_services() -> None: + hass.services.async_remove(DOMAIN, "image_analyzer") + hass.services.async_remove(DOMAIN, "video_analyzer") + + return async_unload_services diff --git a/custom_components/zhipuai/services.yaml b/custom_components/zhipuai/services.yaml new file mode 100644 index 0000000..6c88c6e --- /dev/null +++ b/custom_components/zhipuai/services.yaml @@ -0,0 +1,221 @@ +image_analyzer: + name: 图像分析 + description: 使用智谱GLM-4V模型分析图像 + fields: + model: + name: 模型 + description: '选择要使用的图像分析模型' + required: true + example: "glm-4v-flash" + default: "glm-4v-flash" + selector: + select: + options: + - "glm-4v-plus" + - "glm-4v" + - "glm-4v-flash" + message: + name: 提示词 + required: true + description: '给模型的提示词' + example: "请描述这张图片的内容" + selector: + text: + multiline: true + image_file: + name: 图片文件 + required: false + description: '本地图片路径(支持jpg、png、jpeg格式,最大5MB,最大分辨率6000x6000像素)' + example: "/config/www/tmp/front_door.jpg" + selector: + text: + multiline: false + image_entity: + name: 图片实体 + required: false + description: '要分析的图片或摄像头实体' + example: 'camera.front_door' + selector: + entity: + domain: ["image", "camera"] + multiple: false + temperature: + name: 温度 + required: false + description: '控制输出的随机性(0.1-1.0)。值越低,输出越稳定' + example: 0.8 + default: 0.8 + selector: + number: + min: 0.1 + max: 1.0 + step: 0.1 + max_tokens: + name: 最大令牌数 + required: false + description: '限制生成文本的最大长度' + example: 1024 + default: 1024 + selector: + number: + min: 1 + max: 1024 + step: 1 + stream: + name: 流式响应 + required: false + description: '是否使用流式响应(实时返回生成结果)' + example: false + default: false + selector: + boolean: {} + +video_analyzer: + name: 视频分析 + description: 使用智谱GLM-4V-Plus模型分析视频 + fields: + model: + name: 模型 + description: '视频分析模型 (仅支持 GLM-4V-Plus)' + required: false + default: "glm-4v-plus" + selector: + select: + options: + - "glm-4v-plus" + message: + name: 提示词 + required: true + description: '给模型的提示词' + example: "请描述这段视频的内容" + selector: + text: + multiline: true + video_file: + name: 视频文件 + required: true + description: '本地视频文件路径(支持mp4格式,建议时长不超过30秒)' + example: "/config/www/tmp/video.mp4" + selector: + text: + multiline: false + temperature: + name: 温度 + required: false + description: '控制输出的随机性(0.1-1.0)。值越低,输出越稳定' + example: 0.8 + default: 0.8 + selector: + number: + min: 0.1 + max: 1.0 + step: 0.1 + max_tokens: + name: 最大令牌数 + required: false + description: '限制生成文本的最大长度' + example: 1024 + default: 1024 + selector: + number: + min: 1 + max: 1024 + step: 1 + stream: + name: 流式响应 + required: false + description: '是否使用流式响应(实时返回生成结果)' + example: false + default: false + selector: + boolean: {} + +image_gen: + name: 图像生成 + description: 使用 CogView-3 模型生成图像 + fields: + prompt: + name: 图像描述 + description: 所需图像的文本描述 + required: true + example: "一只可爱的小猫咪" + selector: + text: + multiline: true + model: + name: 模型 + description: 选择要使用的模型版本 + required: false + default: cogview-3-flash + selector: + select: + options: + - label: CogView-3 Plus + value: cogview-3-plus + - label: CogView-3 + value: cogview-3 + - label: CogView-3 Flash (免费) + value: cogview-3-flash + size: + name: 图片尺寸 + description: 生成图片的尺寸大小 + required: false + default: 1024x1024 + selector: + select: + options: + - label: 1024x1024 + value: 1024x1024 + - label: 768x1344 + value: 768x1344 + - label: 864x1152 + value: 864x1152 + - label: 1344x768 + value: 1344x768 + - label: 1152x864 + value: 1152x864 + - label: 1440x720 + value: 1440x720 + - label: 720x1440 + value: 720x1440 + +web_search: + name: 联网搜索 + description: 使用智谱AI的web-search-pro工具进行联网搜索 + fields: + query: + name: 搜索内容 + description: '要搜索的内容' + required: true + example: "中国队奥运会拿了多少奖牌" + selector: + text: + multiline: true + stream: + name: 流式响应 + description: '是否使用流式响应(实时返回生成结果)' + required: false + default: false + selector: + boolean: {} + +entity_analysis: + name: 实体历史记录 + description: 获取实体的历史状态记录(如人在传感器、灯光、温度、湿度、光照度变化记录等) + fields: + entity_id: + name: 实体ID + description: 要获取历史记录的实体ID + required: true + selector: + entity: + multiple: true + days: + name: 天数 + description: 要获取的历史记录天数(1-15天) + default: 3 + selector: + number: + min: 1 + max: 15 + mode: box diff --git a/custom_components/zhipuai/translations/en.json b/custom_components/zhipuai/translations/en.json new file mode 100644 index 0000000..6e3dbf9 --- /dev/null +++ b/custom_components/zhipuai/translations/en.json @@ -0,0 +1,102 @@ +{ + "title": "Clear words of wisdom", + "config": { + "step": { + "user": { + "data": { + "name": "custom name", + "api_key": "API key" + }, + "description": "Get the key: [Click the link](https://www.bigmodel.cn/invite?icode=9niOGcfvBKiiCpCLI4tgtX3uFJ1nZ0jLLgipQkYjpcA%3D)" + }, + "reauth_confirm": { + "title": "Re-verify Zhipu AI", + "description": "Your Zhipu AI API key has expired, please enter a new API key", + "data": { + "api_key": "API key" + } + }, + "reconfigure_confirm": { + "title": "Reconfigure Zhipu AI", + "description": "Please enter new configuration information", + "data": { + "api_key": "API key" + } + }, + "history": { + "title": "Historical data analysis configuration", + "description": "Select the entities to be analyzed and the number of days of historical data", + "data": { + "history_entities": "Select entity", + "history_days": "Number of days of historical data (1-15 days)" + } + } + }, + "error": { + "cannot_connect": "Unable to connect to service", + "invalid_auth": "API key error", + "unknown": "unknown error", + "cooldown_too_small": "The cooling time value {value} is too small, please set a value greater than or equal to 0!", + "cooldown_too_large": "The cooling time value {value} is too large, please set a value less than or equal to 10!", + "model_not_found": "The specified model could not be found", + "invalid_api_key": "API Key format error" + }, + "abort": { + "already_configured": "The device has been configured", + "reauth_successful": "Re-authentication successful", + "reconfigure_successful": "Reconfiguration successful" + } + }, + "options": { + "step": { + "init": { + "data": { + "chat_model": "chat model", + "temperature": "temperature", + "max_tokens": "Maximum number of tiles", + "max_history_messages": "Maximum number of historical messages", + "top_p": "Top P", + "prompt": "prompt word template", + "max_tool_iterations": "Maximum number of tool iterations", + "cooldown_period": "Cooling time (seconds)", + "request_timeout": "Request timeout (seconds)", + "llm_hass_api": "Home Assistant LLM API", + "recommended": "Use recommended model settings", + "web_search": "Internet analytics search", + "history_analysis": "Entity historical data analysis" + }, + "data_description": { + "prompt": "Indicates how LLM should respond. This can be a template.", + "chat_model": "Please select the chat model you want to use. By default, please select the free universal 128K model. If you want a better experience, you can choose to support other paid models. The actual cost is not high. Please check the official website billing standards for details.", + "max_tokens": "Set the maximum number of tokens returned in the response", + "temperature": "Controls the randomness of the output (0-2)", + "top_p": "Control output diversity (0-1)", + "llm_hass_api": "Opt-in Home Assistant LLM", + "recommended": "Use recommended model settings", + "max_history_messages": "Set the maximum number of historical messages to retain. Function: Control the memory function of input content. The memory function can ensure smooth contextual dialogue. Generally, it is best to control home equipment within 5 times. It is effective for requests that cannot be processed smoothly. For other daily conversations, the threshold can be set to more than 10 times.", + "max_tool_iterations": "Set the maximum number of tool calls in a single session. Its function is to set the call threshold for the system LLM call request. If an error occurs, it can ensure that the system will not freeze. Especially for the design of various small hosts with weak performance, it is recommended to set it 20-30 times.", + "cooldown_period": "Set the minimum interval between two conversation requests (0-10 seconds). Function: The request will be delayed for a period of time before being sent. It is recommended to set it within 3 seconds to ensure that the content sending request fails due to frequency factors.", + "request_timeout": "Set the timeout for AI requests (10-120 seconds). Function: Control the maximum time to wait for AI response. You may need to increase this value when generating longer text. Suggestion: If it is a short quick response dialogue, or remove the AI ​​timeout setting, you can set it to about 10 seconds. If AI errors often occur, you can increase this value appropriately. Generally, 30 seconds is enough for a conversation. If you need to generate a long text of more than 1,000 words, it is recommended to set it for more than 60 seconds. If timeout errors often occur, you can increase this value appropriately." + } + }, + "history": { + "title": "Entity historical data analysis configuration", + "description": "Provides **entity historical data analysis** in scenarios that cannot be achieved by **Jinja2 template** (Home Assistant's template system) to ensure that AI understands and analyzes your device data. For example: it can be used to automatically help you analyze home security , personnel activity trajectories, daily life summary, UI text template introduction, etc.\n\n• Support **AI-assisted analysis** historical data (let AI understand and analyze your device data)\n• Provide intelligent decision support for **device management**\n• It is recommended to control within the range of **1 day historical data** for best results\n• **Special Reminder**: For environmental sensors that update frequently such as temperature, humidity, and illumination, please avoid selecting Prevent AI Overflow (you can set it according to the default 10 minutes)", + "data": { + "history_entities": "Select entity", + "history_days": "Get the range of days the entity has been in the repository (1-15 days)", + "history_interval": "Get the update time (minutes) of the entity in the repository" + } + } + }, + "error": { + "no_entities": "Please select at least one entity", + "invalid_days": "The number of historical data days must be between 1-15 days" + } + }, + "exceptions": { + "invalid_config_entry": { + "message": "The provided configuration entry is invalid. What you get is {config_entry}" + } + } +} \ No newline at end of file diff --git a/custom_components/zhipuai/translations/zh-Hans.json b/custom_components/zhipuai/translations/zh-Hans.json new file mode 100644 index 0000000..a5f3a45 --- /dev/null +++ b/custom_components/zhipuai/translations/zh-Hans.json @@ -0,0 +1,102 @@ +{ + "title": "智谱清言", + "config": { + "step": { + "user": { + "data": { + "name": "自定义名称", + "api_key": "API 密钥" + }, + "description": "获取密钥:[点击链接](https://www.bigmodel.cn/invite?icode=9niOGcfvBKiiCpCLI4tgtX3uFJ1nZ0jLLgipQkYjpcA%3D)" + }, + "reauth_confirm": { + "title": "重新验证 智谱AI", + "description": "您的 智谱AI API密钥已失效,请输入新的 API 密钥", + "data": { + "api_key": "API 密钥" + } + }, + "reconfigure_confirm": { + "title": "重新配置 智谱AI", + "description": "请输入新的配置信息", + "data": { + "api_key": "API 密钥" + } + }, + "history": { + "title": "历史数据分析配置", + "description": "选择需要分析的实体和历史数据天数", + "data": { + "history_entities": "选择实体", + "history_days": "历史数据天数 (1-15天)" + } + } + }, + "error": { + "cannot_connect": "无法连接到服务", + "invalid_auth": "API密钥错误", + "unknown": "未知错误", + "cooldown_too_small": "冷却时间值 {value} 太小,请设置大于等于 0 的值!", + "cooldown_too_large": "冷却时间值 {value} 太大,请设置小于等于 10 的值!", + "model_not_found": "找不到指定的模型", + "invalid_api_key": "API Key 格式错误" + }, + "abort": { + "already_configured": "设备已经配置", + "reauth_successful": "重新认证成功", + "reconfigure_successful": "重新配置成功" + } + }, + "options": { + "step": { + "init": { + "data": { + "chat_model": "聊天模型", + "temperature": "温度", + "max_tokens": "最大令牌数", + "max_history_messages": "最大历史消息数", + "top_p": "Top P", + "prompt": "提示词模板", + "max_tool_iterations": "最大工具迭代次数", + "cooldown_period": "冷却时间(秒)", + "request_timeout": "请求超时(秒)", + "llm_hass_api": "Home Assistant LLM API", + "recommended": "使用推荐的模型设置", + "web_search": "互联网分析搜索", + "history_analysis": "实体历史数据分析" + }, + "data_description": { + "prompt": "指示 LLM 应如何响应。这可以是一个模板。", + "chat_model": "请选择要使用的聊天模型,默认请选择免费通用128K模型,如需更好体验可选择支持其他付费模型,实际费用不高,具体请查看官网计费标准。", + "max_tokens": "设置响应中返回的最大令牌数", + "temperature": "控制输出的随机性(0-2)", + "top_p": "控制输出多样性(0-1)", + "llm_hass_api": "选择启用 Home Assistant LLM ", + "recommended": "使用推荐的模型设置", + "max_history_messages": "设置要保留的最大历史消息数。功能:控制输入内容的记忆功能,记忆功能可以保证上下文对话顺畅,一般控制家居设备最好控制在5次以内,对请求不能顺利进行有效,其他日常对话可以设置阈值在10次以上。", + "max_tool_iterations": "设置单次对话中的最大工具调用次数。其功能是对系统LLM调用请求设置调用阈值,如果出错可以保证系统不会卡死,尤其是对各种性能较弱的小主机的设计,建议设置20-30次。", + "cooldown_period": "设置两次对话请求的最小间隔时间(0-10秒)。作用:请求会延迟一段时间再发送,建议设置在3秒以内,保证因为频率因素导致内容发送请求失败。", + "request_timeout": "设置AI请求的超时时间(10-120秒)。作用:控制等待AI响应的最长时间,生成较长文本时可能需要增加此值。建议:如果是较短快速响应对话,可以设置10秒左右。如果出现AI报错,可以适当增加此值。如需生成超过1000字的长文本,建议设置60秒以上。" + } + }, + "history": { + "title": "实体历史数据分析 配置", + "description": "在**Jinja2模版**(Home Assistant的模板系统)无法实现的场景下提供**实体历史数据分析**,保证AI理解并分析您的设备数据,举例:可以用于自动化帮您分析家中安防、人员活动轨迹,日常生活总结,UI文本模版介绍等。\n\n• 支持**AI辅助分析**历史数据(让AI理解并分析您的设备数据)\n• 为**设备管理**提供智能决策支持\n• 建议控制在**1天历史数据**范围内以获得最佳效果\n• **特别提醒**:对于温湿度、光照度等频繁更新的环境传感器,请避免选择防止AI溢出(可以按照默认10分钟设置)", + "data": { + "history_entities": "选择实体", + "history_days": "获取实体在存储库中的天数范围 (1-15天)", + "history_interval": "获取实体在存储库中的更新时间(分钟)" + } + } + }, + "error": { + "no_entities": "请选择至少一个实体", + "invalid_days": "历史数据天数必须在 1-15天之间" + } +}, +"exceptions": { + "invalid_config_entry": { + "message": "提供的配置条目无效。得到的是 {config_entry}" + } +} +} \ No newline at end of file diff --git a/custom_components/zhipuai/web_search.py b/custom_components/zhipuai/web_search.py new file mode 100644 index 0000000..e82d909 --- /dev/null +++ b/custom_components/zhipuai/web_search.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import json +import logging +import uuid +import time +from typing import Any, Dict + +import voluptuous as vol +import requests +from homeassistant.core import HomeAssistant, ServiceCall, callback +from homeassistant.exceptions import HomeAssistantError, ServiceValidationError +from homeassistant.helpers import config_validation as cv + +from .const import ( + DOMAIN, + LOGGER, + ZHIPUAI_WEB_SEARCH_URL, + CONF_WEB_SEARCH, + DEFAULT_WEB_SEARCH +) + +WEB_SEARCH_SCHEMA = vol.Schema({ + vol.Required("query"): cv.string, + vol.Optional("stream", default=False): cv.boolean, +}) + +async def async_setup_web_search(hass: HomeAssistant) -> None: + + async def handle_web_search(call: ServiceCall) -> None: + try: + config_entries = hass.config_entries.async_entries(DOMAIN) + if not config_entries: + raise ValueError("未找到 ZhipuAI 配置") + + entry = config_entries[0] + if not entry.options.get(CONF_WEB_SEARCH, DEFAULT_WEB_SEARCH): + raise ValueError("联网搜索功能已关闭,请在配置中开启") + + api_key = entry.data.get("api_key") + if not api_key: + raise ValueError("在配置中找不到 API 密钥") + + query = call.data["query"] + stream = call.data.get("stream", False) + request_id = str(uuid.uuid4()) + + messages = [{"role": "user", "content": query}] + payload = { + "request_id": request_id, + "tool": "web-search-pro", + "stream": stream, + "messages": messages + } + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + try: + response = await hass.async_add_executor_job( + lambda: requests.post( + ZHIPUAI_WEB_SEARCH_URL, + headers=headers, + json=payload, + stream=stream, + timeout=300 + ) + ) + response.raise_for_status() + except requests.exceptions.RequestException as e: + LOGGER.error(f"API请求失败: {str(e)}") + raise ServiceValidationError(f"API请求失败: {str(e)}") + + if stream: + event_id = f"zhipuai_response_{int(time.time())}" + + try: + hass.bus.async_fire(f"{DOMAIN}_stream_start", { + "event_id": event_id, + "type": "web_search" + }) + + accumulated_text = "" + for line in response.iter_lines(): + if line: + try: + line_text = line.decode('utf-8').strip() + if line_text.startswith('data: '): + line_text = line_text[6:] + if line_text == '[DONE]': + break + + json_data = json.loads(line_text) + if 'choices' in json_data and json_data['choices']: + choice = json_data['choices'][0] + if 'message' in choice and 'tool_calls' in choice['message']: + tool_calls = choice['message']['tool_calls'] + for tool_call in tool_calls: + if tool_call.get('type') == 'search_result': + search_results = tool_call.get('search_result', []) + for result in search_results: + content = result.get('content', '') + if content: + accumulated_text += content + "\n" + hass.bus.async_fire( + f"{DOMAIN}_stream_token", + { + "event_id": event_id, + "content": content, + "full_content": accumulated_text + } + ) + except json.JSONDecodeError as e: + LOGGER.error(f"解析流式响应失败: {str(e)}") + continue + + hass.bus.async_fire(f"{DOMAIN}_stream_end", { + "event_id": event_id, + "full_content": accumulated_text + }) + + hass.bus.async_fire(f"{DOMAIN}_response", { + "type": "web_search", + "content": accumulated_text, + "success": True + }) + return {"success": True, "event_id": event_id, "message": accumulated_text} + + except Exception as e: + error_msg = f"处理流式响应时出错: {str(e)}" + LOGGER.error(error_msg) + hass.bus.async_fire(f"{DOMAIN}_stream_error", { + "event_id": event_id, + "error": error_msg + }) + return {"success": False, "message": error_msg} + else: + result = response.json() + content = "" + if result.get("choices") and result["choices"][0].get("message", {}).get("tool_calls"): + tool_calls = result["choices"][0]["message"]["tool_calls"] + for tool_call in tool_calls: + if tool_call.get("type") == "search_result": + search_results = tool_call.get("search_result", []) + for result in search_results: + if result_content := result.get("content"): + content += result_content + "\n" + + if content: + hass.bus.async_fire(f"{DOMAIN}_response", { + "type": "web_search", + "content": content, + "success": True + }) + return {"success": True, "message": content} + else: + error_msg = "未从API获取到搜索结果" + return {"success": False, "message": error_msg} + + except Exception as e: + error_msg = f"Web search failed: {str(e)}" + LOGGER.error(f"网络搜索错误: {str(e)}") + return {"success": False, "message": error_msg} + + hass.services.async_register( + DOMAIN, + "web_search", + handle_web_search, + schema=WEB_SEARCH_SCHEMA, + supports_response=True + ) + + @callback + def async_unload_services() -> None: + hass.services.async_remove(DOMAIN, "web_search") + + return async_unload_services diff --git a/hacs.json b/hacs.json new file mode 100644 index 0000000..98422c7 --- /dev/null +++ b/hacs.json @@ -0,0 +1,5 @@ +{ + "name": "智谱清言", + "render_readme": true, + "homeassistant": "2024.11.0" +}