From 0cf6aa10a1ea2fafedf69628bd61ae835b6a8f6f Mon Sep 17 00:00:00 2001 From: KNOOP Date: Sun, 29 Dec 2024 16:44:06 +0800 Subject: [PATCH] Release v2024.12.31 --- .DS_Store | Bin 0 -> 6148 bytes README.md | 205 +++++++ custom_components/.DS_Store | Bin 0 -> 6148 bytes custom_components/zhipuai/.DS_Store | Bin 0 -> 6148 bytes custom_components/zhipuai/__init__.py | 86 +++ custom_components/zhipuai/ai_request.py | 23 + custom_components/zhipuai/config_flow.py | 397 +++++++++++++ custom_components/zhipuai/const.py | 57 ++ custom_components/zhipuai/conversation.py | 546 ++++++++++++++++++ custom_components/zhipuai/entity_analysis.py | 118 ++++ custom_components/zhipuai/image_gen.py | 109 ++++ custom_components/zhipuai/intents.py | 287 +++++++++ custom_components/zhipuai/intents.yaml | 178 ++++++ custom_components/zhipuai/manifest.json | 12 + custom_components/zhipuai/services.py | 440 ++++++++++++++ custom_components/zhipuai/services.yaml | 221 +++++++ .../zhipuai/translations/en.json | 96 +++ .../zhipuai/translations/zh-Hans.json | 96 +++ custom_components/zhipuai/web_search.py | 172 ++++++ hacs.json | 5 + 20 files changed, 3048 insertions(+) create mode 100644 .DS_Store create mode 100644 README.md create mode 100644 custom_components/.DS_Store create mode 100644 custom_components/zhipuai/.DS_Store create mode 100644 custom_components/zhipuai/__init__.py create mode 100644 custom_components/zhipuai/ai_request.py create mode 100644 custom_components/zhipuai/config_flow.py create mode 100644 custom_components/zhipuai/const.py create mode 100644 custom_components/zhipuai/conversation.py create mode 100644 custom_components/zhipuai/entity_analysis.py create mode 100644 custom_components/zhipuai/image_gen.py create mode 100644 custom_components/zhipuai/intents.py create mode 100644 custom_components/zhipuai/intents.yaml create mode 100644 custom_components/zhipuai/manifest.json create mode 100644 custom_components/zhipuai/services.py create mode 100644 custom_components/zhipuai/services.yaml create mode 100644 custom_components/zhipuai/translations/en.json create mode 100644 custom_components/zhipuai/translations/zh-Hans.json create mode 100644 custom_components/zhipuai/web_search.py create mode 100644 hacs.json diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..fac6e6356a491f9984e234a8ddd3ee43db21ac80 GIT binary patch literal 6148 zcmeHKL5tHs6n@i2cj_us*oDQ@fCn!{+wCfdm$cUNf)PEa)TSxjV3G+SJr4c z7cTDcJ-*MMduOun@=3nThLgopYCV!N@;CK~|2Ui#`9me zvKY!`5@d?&=>Tu>*0_DNTAlPeqJP|9cf_jKJwEA(+js7+*DZc{{novQ z8p9n>*%gBuc!QN1GYzt74$5d-@jFhSYy0ceqY|-V^$Wv z(XBbZSKcCtWHQI}jj+VD_6HR(!gAUJtXT!D0#<=-1!#W|=!CApQlr{BuuxY3LC zv + +--- +## 通知:本项目未经允许严禁商用,你可以隐晦但不能作为盈利手段。 + +### 📦 安装步骤 + +#### 1. HACS 添加自定义存储库 +在 Home Assistant 的 HACS 中,点击右上角的三个点,选择“自定义存储库”,并添加以下 URL: +``` +https://github.com/knoop7/zhipuai +``` + +#### 2. 添加智谱清言集成 +进入 Home Assistant 的“集成”页面,搜索并添加“智谱清言”。 + +#### 3. 配置 Key 🔑 +在配置页面中,你可以通过手机号登录获取 Key。获取后,直接填写 Key 使用,不需要进行额外验证。 +**注意**:建议你新建一个 Key,避免使用系统默认的 Key。 + +#### 4. 免费模型使用 💡 +智谱清言默认选择了免费模型,完全免费,不用担心收费。如果你有兴趣,还可以选择其他付费模型来体验更丰富的功能。 + +#### 5. 版本兼容性 📅 +请确保 Home Assistant 的版本不低于 8.0,因为智谱清言主要针对最新版本开发。如果遇到无法识别的实体问题,建议重启系统或更新至最新版本。 + +--- + +### 🛠 模型指令使用示例 +为了保证大家能使用舒畅,并且不出任何bug可以使用我的模版指令进行尝试 + +```` + +作为 Home Assistant 的智能家居管理者,你的名字叫“自定义”,我将为您提供智能家居信息和问题的解答。请查看以下可用设备、状态及操作示例。 + + +### 可用设备展示 + + +# 注意如果实体超过1000以上 +# 直接删掉这句话 + +### 今日油价: +```yaml +{% set sensor = 油价实体 %} +Sensor: {{ sensor.name }} +State: {{ sensor.state }} + +Attributes: +{% for attribute, value in sensor.attributes.items() %} + {{ attribute }}: {{ value }} +{% endfor %} +``` + +### 电费余额信息: +```yaml +{% set balance_sensor = 电费实体 %} + +{% if balance_sensor %} +当前余额: {{ balance_sensor.state }} {{ balance_sensor.attributes.unit_of_measurement }} +{% endif %} +``` + +### Tasmota能源消耗: +```yaml +{% set today_sensor = states.sensor.tasmota_energy_today %} +{% set yesterday_sensor = states.sensor.tasmota_energy_yesterday %} + +{% if today_sensor is not none and yesterday_sensor is not none %} +今日消耗: {{ today_sensor.state }} {{ today_sensor.attributes.unit_of_measurement }} +昨日消耗: {{ yesterday_sensor.state }} {{ yesterday_sensor.attributes.unit_of_measurement }} +{% endif %} +``` + + +### 此时天气: +```json +{% set entity_id = '天气实体' %} +{% set entity = states[entity_id] %} +{ + "state": "{{ entity.state }}", + "attributes": { + {% for attr in entity.attributes %} + {% if attr not in ['hourly_temperature', 'hourly_skycon', 'hourly_cloudrate', 'hourly_precipitation'] %} + "{{ attr }}": "{{ entity.attributes[attr] }}"{% if not loop.last %},{% endif %} + {% endif %} + {% endfor %} + } +} +```` + +或者这个模版指令 +```` +### 可用设备展示 +```csv +entity_id,name,state,category +{%- for entity in states if 'automation.' not in entity.entity_id and entity.state not in ['unknown'] and not ('device_tracker.' in entity.entity_id and ('huawei' in entity.entity_id or 'Samsung' in entity.entity_id)) and 'iphone' not in entity.entity_id and 'daily_english' not in entity.entity_id and 'lenovo' not in entity.entity_id and 'time' not in entity.entity_id and 'zone' not in entity.entity_id and 'n1' not in entity.entity_id and 'z470' not in entity.entity_id and 'lao_huang' not in entity.entity_id and 'lao_huang_li' not in entity.entity_id and 'input_text' not in entity.entity_id and 'conversation' not in entity.entity_id and 'camera' not in entity.entity_id and 'update' not in entity.entity_id and 'IPhone' not in entity.entity_id and 'mac' not in entity.entity_id and 'macmini' not in entity.entity_id and 'macbook' not in entity.entity_id and 'ups' not in entity.entity_id and 'OPENWRT' not in entity.entity_id and 'OPENWRT' not in entity.entity_id%} +{%- set category = '其他' %} +{%- if 'light.' in entity.entity_id %}{% set category = '灯' %} +{%- elif 'sensor.' in entity.entity_id and 'battery' in entity.entity_id %} + {% set category = '电池' %} +{%- elif 'sensor.' in entity.entity_id and 'sun' in entity.entity_id %} + {% set category = '太阳' %} +{%- elif 'sensor.' in entity.entity_id and ('motion' in entity.entity_id or 'presence' in entity.entity_id) %} + {% set category = '人体存在' %} +{%- elif 'sensor.' in entity.entity_id and ('motion' in entity.entity_id or 'presence' in entity.entity_id) %} + {% set category = '人体存在' %} +{%- elif 'climate.' in entity.entity_id %}{% set category = '空调' %} +{%- elif 'media_player.' in entity.entity_id %}{% set category = '媒体播放器' %} +{%- elif 'cover.' in entity.entity_id %}{% set category = '门窗' %} +{%- elif 'lock.' in entity.entity_id %}{% set category = '门锁' %} +{%- elif 'switch.' in entity.entity_id %}{% set category = '开关' %} +{%- elif 'sensor.' in entity.entity_id %}{% set category = '传感器' %} +{%- elif 'watering.' in entity.entity_id %}{% set category = '浇花器' %} +{%- elif 'fan.' in entity.entity_id %}{% set category = '风扇' %} +{%- elif 'air_quality.' in entity.entity_id %}{% set category = '空气质量' %} +{%- elif 'vacuum.' in entity.entity_id %}{% set category = '扫地机器人' %} +{%- elif 'person.' in entity.entity_id %}{% set category = '人员' %} +{%- elif 'binary_sensor.' in entity.entity_id and ('door' in entity.entity_id or 'window' in entity.entity_id) %}{% set category = '门窗' %} +{%- elif 'gas.' in entity.entity_id %}{% set category = '天然气' %} +{%- elif 'energy.' in entity.entity_id %}{% set category = '用电量' %} +{%- elif 'script.' in entity.entity_id %}{% set category = '脚本' %} +{%- elif 'scene.' in entity.entity_id %}{% set category = '场景' %} +{%- endif %} +{{- entity.entity_id }},{{ entity.name }},{{ entity.state }},{{ category }} +{%- endfor %} + +```` + +--- + +### 使用内置 API 公开实体 🌐 +你可以使用智谱清言内置的 API 来公开实体,并为其设置别名。通过重新命名实体,你可以避免使用系统默认名称造成的混乱,提升管理效率。 + +--- + +### 🚀 使用指南 + +1. **访问界面** + 打开 Home Assistant 仪表板,找到“智谱清言”集成卡片或对应的集成页面。 + +2. **输入指令** + 在集成页面或对话框中,输入自然语言指令,或使用语音助手下达命令。 + +3. **查看响应** + 系统会根据你的指令执行任务,设备状态变化将实时显示并反馈。 + +4. **探索功能** + 你可以尝试不同的指令来控制家中的智能设备,或查询相关状态。 + +--- + +### 📑 常用指令示例 + +- "打开客厅灯" +- "将卧室温度调到 22 度" +- "播放音乐" +- "明早 7 点提醒我备忘" +- "检查门锁状态" +- "看看全屋温度湿度“ + +--- + +### 🛠 Bug 处理 +如果你在使用过程中遇到持续的 Python 错误,建议重启对话框并重新加载环境。这样可以解决一些潜在的代码问题。 + +--- + +### 🗂 处理不被 Home Assistant 认可的实体 +如果 Home Assistant 中存在不被认可的实体,你可以将这些实体剔除出自动化控制的范围。通过在指令中添加 Jinja2 模板,可以有效避免 Python 的错误提示,杜绝潜在问题。 + +--- + +### 额外提示 + +- **系统版本要求**:智谱清言需要 Home Assistant 至少 8.0 版本支持。 +- **建议**:如果遇到兼容性问题,建议重启或更新系统。通常这能解决大多数问题。 +- **相关项目** 如果需要语音转文字可以使用免费在线AI模型集成,个人二次深度修改 ````https://github.com/knoop7/groqcloud_whisper```` + + +--- + +### 📊 实时状态 + +#### 当前时间:16:09:23,今日日期:2024-10-12。 + +#### 油价信息 ⛽ +- 92号汽油:7元/升 +- 95号汽油:7元/升 +- 98号汽油:8元/升 +预计下次油价调整时间为10月23日24时,油价可能继续上涨。 + +#### 电费余额 ⚡ +- 当前余额:27.5元 + +#### 今日能源消耗 💡 +- 今日消耗:4033.0 Wh +- 昨日消耗:7.558 kWh + +#### 今日新闻摘要 📰 +1. 民政部发布全国老年人口数据。 diff --git a/custom_components/.DS_Store b/custom_components/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..9c176340b0a170d0af19401acce1e0ddc6b353fe GIT binary patch literal 6148 zcmeHKy-ve05Wb_Sh}5McqhFx|6N{qFZa!YXih{WnjniZA~HQ{Ix*)dkaLbb-O;vrT(@;MlISl^ z$+w@<6K!ZqJso_1*J4XwP3`?+weEI(-brzBe(bU9s;X|g8kX|m_4?-J`R#JdtNDgk zzd2ZK8j;4u8E^)i0cXG&_&Eb|!9mKIqW8{#GvEw-G9c$eKog9HNiiQC=yC}FEHIq~ zy3`U96AYtaQiKJ<8Vb}5%ZwKJ!w?@i1J1xdV}OflRV}e8yIYU8CwFZ?yG0Wbza$C-`rr|Oft({} d*{Jp)I{czxQj{z*pTmLv5Xgjh=M4M;1Mfw|HCzAy literal 0 HcmV?d00001 diff --git a/custom_components/zhipuai/.DS_Store b/custom_components/zhipuai/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..eabf78c22cc35e57db25f7fdcc7b8e42042bfd95 GIT binary patch literal 6148 zcmeHKu}%Xq47H)dNicL|%*Nb-JygfQ+#Nxw1V~q|qGi8N=r8aKd<5Ucd$B=1ajFCZ zLdceq=Onh{d(tJbi0JfrF%cPyNQFjJDg(ml!KDK?9s#w)SW6=}vXKXA2Lk=YDdv8R zjGvJAqwMC7QFAw$-L}h>C4Tz4&vx0?^}K24u&K6Zmsc;_xAXn>i*MPl?ss#^2ptj( z1Ovf9Fc1vw(Z1)8gRiJ|6>`Q&!lu`@JxQ7=B!cm60|wCs-fshEpq!{~#7V4%;yp$*5p|1a>% z3>Nu)mly>D!N5ObKu7hgp5mkYZvFCkylWHMH5!F>9Tf=l-Xj15o+Bq^(bf}n*k#Ag UkWn~a+=1~BD1<~84EzEE?~Z6TL;wH) literal 0 HcmV?d00001 diff --git a/custom_components/zhipuai/__init__.py b/custom_components/zhipuai/__init__.py new file mode 100644 index 0000000..a3494cc --- /dev/null +++ b/custom_components/zhipuai/__init__.py @@ -0,0 +1,86 @@ +from __future__ import annotations +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import CONF_API_KEY, Platform +from homeassistant.core import HomeAssistant, callback +from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.helpers.dispatcher import async_dispatcher_send +from .const import DOMAIN, LOGGER +from .intents import get_intent_handler, async_setup_intents +from .services import async_setup_services +from .web_search import async_setup_web_search +from .image_gen import async_setup_image_gen +from .entity_analysis import async_setup_entity_analysis + +PLATFORMS: list[Platform] = [Platform.CONVERSATION] + +class ZhipuAIConfigEntry: + def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry): + self.hass = hass + self.config_entry = config_entry + self.api_key = config_entry.data[CONF_API_KEY] + self.options = config_entry.options + self._unsub_options_update_listener = None + self._cleanup_callbacks = [] + self.intent_handler = get_intent_handler(hass) + + @property + def entry_id(self): + return self.config_entry.entry_id + + @property + def title(self): + return self.config_entry.title + + async def async_setup(self) -> None: + self._unsub_options_update_listener = self.config_entry.add_update_listener( + self.async_options_updated + ) + + async def async_unload(self) -> None: + if self._unsub_options_update_listener is not None: + self._unsub_options_update_listener() + self._unsub_options_update_listener = None + for cleanup_callback in self._cleanup_callbacks: + cleanup_callback() + self._cleanup_callbacks.clear() + + def async_on_unload(self, func): + self._cleanup_callbacks.append(func) + + @callback + async def async_options_updated(self, hass: HomeAssistant, entry: ConfigEntry) -> None: + self.options = entry.options + async_dispatcher_send(hass, f"{DOMAIN}_options_updated", entry) + +async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + try: + zhipuai_entry = ZhipuAIConfigEntry(hass, entry) + await zhipuai_entry.async_setup() + hass.data.setdefault(DOMAIN, {})[entry.entry_id] = zhipuai_entry + + unload_services = await async_setup_services(hass) + zhipuai_entry.async_on_unload(unload_services) + + await async_setup_intents(hass) + + await async_setup_web_search(hass) + await async_setup_image_gen(hass) + await async_setup_entity_analysis(hass) + + await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) + return True + + except Exception as ex: + raise ConfigEntryNotReady from ex + +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) + try: + zhipuai_entry = hass.data[DOMAIN].get(entry.entry_id) + if zhipuai_entry is not None and hasattr(zhipuai_entry, 'async_unload'): + await zhipuai_entry.async_unload() + except Exception: + pass + finally: + hass.data[DOMAIN].pop(entry.entry_id, None) + return unload_ok \ No newline at end of file diff --git a/custom_components/zhipuai/ai_request.py b/custom_components/zhipuai/ai_request.py new file mode 100644 index 0000000..6a6a307 --- /dev/null +++ b/custom_components/zhipuai/ai_request.py @@ -0,0 +1,23 @@ +import aiohttp +from aiohttp import TCPConnector +from homeassistant.exceptions import HomeAssistantError +from .const import LOGGER, ZHIPUAI_URL + +async def send_ai_request(api_key: str, payload: dict) -> dict: + try: + connector = TCPConnector(ssl=False) + async with aiohttp.ClientSession(connector=connector) as session: + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + async with session.post(ZHIPUAI_URL, json=payload, headers=headers) as response: + if response.status != 200: + raise HomeAssistantError(f"AI 返回状态 {response.status}") + result = await response.json() + return result + + except Exception as err: + LOGGER.error(f"与 AI 通信时出错: {err}") + raise HomeAssistantError(f"与 AI 通信时出错: {err}") diff --git a/custom_components/zhipuai/config_flow.py b/custom_components/zhipuai/config_flow.py new file mode 100644 index 0000000..919d5a4 --- /dev/null +++ b/custom_components/zhipuai/config_flow.py @@ -0,0 +1,397 @@ +from __future__ import annotations +from typing import Any +from types import MappingProxyType +import voluptuous as vol +import aiohttp +from homeassistant.core import HomeAssistant, callback +from homeassistant.config_entries import ( + ConfigEntry, + ConfigFlow, + ConfigFlowResult, + OptionsFlow, +) +from homeassistant import exceptions +from homeassistant.const import CONF_API_KEY, CONF_NAME, CONF_LLM_HASS_API +from homeassistant.helpers import llm +import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.selector import ( + NumberSelector, + NumberSelectorConfig, + SelectOptionDict, + SelectSelector, + SelectSelectorConfig, + TemplateSelector, +) + +from .const import ( + CONF_PROMPT, + CONF_TEMPERATURE, + DEFAULT_NAME, + CONF_CHAT_MODEL, + CONF_MAX_TOKENS, + CONF_RECOMMENDED, + CONF_TOP_P, + CONF_MAX_HISTORY_MESSAGES, + DOMAIN, + RECOMMENDED_MAX_TOKENS, + RECOMMENDED_TEMPERATURE, + RECOMMENDED_TOP_P, + RECOMMENDED_MAX_HISTORY_MESSAGES, + CONF_MAX_TOOL_ITERATIONS, + CONF_COOLDOWN_PERIOD, + DEFAULT_MAX_TOOL_ITERATIONS, + DEFAULT_COOLDOWN_PERIOD, + CONF_WEB_SEARCH, + DEFAULT_WEB_SEARCH, + CONF_HISTORY_ANALYSIS, + CONF_HISTORY_ENTITIES, + CONF_HISTORY_DAYS, + DEFAULT_HISTORY_ANALYSIS, + DEFAULT_HISTORY_DAYS, + MAX_HISTORY_DAYS, +) + +ZHIPUAI_MODELS = [ + "GLM-4-Plus", + "GLM-4V-Plus", + "GLM-4-0520", + "GLM-4-Long", + "GLM-4-AirX", + "GLM-4-Air", + "GLM-4-FlashX", + "GLM-4-Flash", + "GLM-4V", + "GLM-4-AllTools", + "GLM-4", +] + +RECOMMENDED_CHAT_MODEL = "GLM-4-Flash" + +RECOMMENDED_OPTIONS = { + CONF_RECOMMENDED: True, + CONF_LLM_HASS_API: llm.LLM_API_ASSIST, + CONF_PROMPT: """您是 Home Assistant 的语音助手。 +如实回答有关世界的问题。 +以纯文本形式回答。保持简单明了。""", + CONF_MAX_HISTORY_MESSAGES: RECOMMENDED_MAX_HISTORY_MESSAGES, + CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL, + CONF_MAX_TOOL_ITERATIONS: DEFAULT_MAX_TOOL_ITERATIONS, + CONF_COOLDOWN_PERIOD: DEFAULT_COOLDOWN_PERIOD, +} + +class ZhipuAIConfigFlow(ConfigFlow, domain=DOMAIN): + VERSION = 1 + MINOR_VERSION = 0 + + def __init__(self) -> None: + self._reauth_entry: ConfigEntry | None = None + self._reconfigure_entry: ConfigEntry | None = None + + async def async_step_user( + self, + user_input: dict[str, Any] | None = None, + ) -> ConfigFlowResult: + errors = {} + if user_input is not None: + try: + await self._validate_api_key(user_input[CONF_API_KEY]) + return self.async_create_entry( + title=user_input[CONF_NAME], + data=user_input, + options=RECOMMENDED_OPTIONS, + ) + except UnauthorizedError: + errors["base"] = "invalid_auth" + except aiohttp.ClientError: + errors["base"] = "cannot_connect" + except ModelNotFound: + errors["base"] = "model_not_found" + except Exception: + errors["base"] = "unknown" + + return self.async_show_form( + step_id="user", + data_schema=vol.Schema({ + vol.Required(CONF_NAME, default=DEFAULT_NAME): cv.string, + vol.Required(CONF_API_KEY): cv.string, + }), + errors=errors, + ) + + async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> ConfigFlowResult: + self._reauth_entry = self._get_reauth_entry() + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + errors = {} + + if user_input is not None: + try: + await self._validate_api_key(user_input[CONF_API_KEY]) + assert self._reauth_entry is not None + return self.async_update_reload_and_abort( + self._reauth_entry, + data_updates={CONF_API_KEY: user_input[CONF_API_KEY]}, + reason="reauth_successful", + ) + except UnauthorizedError: + errors["base"] = "invalid_auth" + except aiohttp.ClientError: + errors["base"] = "cannot_connect" + except Exception: + errors["base"] = "unknown" + + return self.async_show_form( + step_id="reauth_confirm", + data_schema=vol.Schema({ + vol.Required(CONF_API_KEY): cv.string, + }), + errors=errors, + ) + + async def async_step_reconfigure(self, entry_data: Mapping[str, Any]) -> ConfigFlowResult: + self._reconfigure_entry = self._get_reconfigure_entry() + return await self.async_step_reconfigure_confirm() + + async def async_step_reconfigure_confirm( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + errors = {} + + if user_input is not None: + try: + await self._validate_api_key(user_input[CONF_API_KEY]) + assert self._reconfigure_entry is not None + return self.async_update_reload_and_abort( + self._reconfigure_entry, + data_updates={CONF_API_KEY: user_input[CONF_API_KEY]}, + reason="reconfigure_successful", + ) + except UnauthorizedError: + errors["base"] = "invalid_auth" + except aiohttp.ClientError: + errors["base"] = "cannot_connect" + except Exception: + errors["base"] = "unknown" + + return self.async_show_form( + step_id="reconfigure_confirm", + data_schema=vol.Schema({ + vol.Required(CONF_API_KEY): cv.string, + }), + errors=errors, + ) + + async def _validate_api_key(self, api_key: str) -> None: + url = "https://open.bigmodel.cn/api/paas/v4/chat/completions" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}" + } + data = { + "model": RECOMMENDED_CHAT_MODEL, + "messages": [{"role": "user", "content": "你好"}] + } + + async with aiohttp.ClientSession() as session: + try: + async with session.post(url, headers=headers, json=data) as response: + if response.status == 200: + return + elif response.status == 401: + raise UnauthorizedError() + else: + response_json = await response.json() + error = response_json.get("error", {}) + error_message = error.get("message", "") + if "model not found" in error_message.lower(): + raise ModelNotFound() + else: + raise InvalidAPIKey() + except aiohttp.ClientError as e: + raise + + @staticmethod + @callback + def async_get_options_flow(config_entry: ConfigEntry) -> ZhipuAIOptionsFlow: + return ZhipuAIOptionsFlow(config_entry) + + +class ZhipuAIOptionsFlow(OptionsFlow): + def __init__(self, config_entry: ConfigEntry) -> None: + self._config_entry = config_entry + self._data = {} + + async def async_step_init( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + errors = {} + if user_input is not None: + try: + cooldown_period = user_input.get(CONF_COOLDOWN_PERIOD) + if cooldown_period is not None: + cooldown_period = float(cooldown_period) + if cooldown_period < 0: + errors[CONF_COOLDOWN_PERIOD] = "cooldown_too_small" + elif cooldown_period > 10: + errors[CONF_COOLDOWN_PERIOD] = "cooldown_too_large" + + if not errors: + self._data.update(user_input) + if user_input.get(CONF_HISTORY_ANALYSIS): + return await self.async_step_history() + return self.async_create_entry(title="", data=self._data) + except ValueError: + errors["base"] = "invalid_option" + + schema = vol.Schema(zhipuai_config_option_schema(self.hass, self._config_entry.options)) + return self.async_show_form( + step_id="init", + data_schema=schema, + errors=errors, + ) + + async def async_step_history( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + errors = {} + current_options = self._config_entry.options + + if user_input is not None: + try: + days = user_input.get(CONF_HISTORY_DAYS, DEFAULT_HISTORY_DAYS) + if days < 1 or days > MAX_HISTORY_DAYS: + errors[CONF_HISTORY_DAYS] = "invalid_days" + if not user_input.get(CONF_HISTORY_ENTITIES): + errors[CONF_HISTORY_ENTITIES] = "no_entities" + + if not errors: + self._data.update(user_input) + return self.async_create_entry(title="", data=self._data) + except ValueError: + errors["base"] = "invalid_option" + + entities = {} + for entity in self.hass.states.async_all(): + friendly_name = entity.attributes.get("friendly_name", entity.entity_id) + entities[entity.entity_id] = f"{friendly_name} ({entity.entity_id})" + + return self.async_show_form( + step_id="history", + data_schema=vol.Schema({ + vol.Required( + CONF_HISTORY_ENTITIES, + default=current_options.get(CONF_HISTORY_ENTITIES, []) + ): SelectSelector( + SelectSelectorConfig( + options=[{"value": k, "label": v} for k, v in entities.items()], + multiple=True, + mode="dropdown", + custom_value=False, + ) + ), + vol.Required( + CONF_HISTORY_DAYS, + default=current_options.get(CONF_HISTORY_DAYS, DEFAULT_HISTORY_DAYS) + ): vol.All( + vol.Coerce(int), + vol.Range(min=1, max=MAX_HISTORY_DAYS), + ), + }), + errors=errors, + ) + + +def zhipuai_config_option_schema( + hass: HomeAssistant, + options: dict[str, Any] | MappingProxyType[str, Any], +) -> dict: + hass_apis = [SelectOptionDict(label="No", value="none")] + hass_apis.extend( + SelectOptionDict(label=api.name, value=api.id) + for api in llm.async_get_apis(hass) + ) + + schema = { + vol.Optional( + CONF_PROMPT, + description={"suggested_value": options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT)}, + ): TemplateSelector(), + vol.Optional( + CONF_LLM_HASS_API, + description={"suggested_value": options.get(CONF_LLM_HASS_API)}, + default="none", + ): SelectSelector(SelectSelectorConfig(options=hass_apis)), + vol.Required( + CONF_RECOMMENDED, + default=options.get(CONF_RECOMMENDED, False) + ): bool, + vol.Optional( + CONF_MAX_HISTORY_MESSAGES, + description={"suggested_value": options.get(CONF_MAX_HISTORY_MESSAGES)}, + default=RECOMMENDED_MAX_HISTORY_MESSAGES, + ): int, + vol.Optional( + CONF_CHAT_MODEL, + description={"suggested_value": options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)}, + default=RECOMMENDED_CHAT_MODEL, + ): SelectSelector(SelectSelectorConfig(options=ZHIPUAI_MODELS)), + vol.Optional( + CONF_MAX_TOOL_ITERATIONS, + description={"suggested_value": options.get(CONF_MAX_TOOL_ITERATIONS, DEFAULT_MAX_TOOL_ITERATIONS)}, + default=DEFAULT_MAX_TOOL_ITERATIONS, + ): int, + vol.Optional( + CONF_COOLDOWN_PERIOD, + description={"suggested_value": options.get(CONF_COOLDOWN_PERIOD, DEFAULT_COOLDOWN_PERIOD)}, + default=DEFAULT_COOLDOWN_PERIOD, + ): vol.All( + vol.Coerce(float), + vol.Range(min=0, max=10), + msg="冷却时间必须在0到10秒之间" + ), + vol.Optional( + CONF_WEB_SEARCH, + default=DEFAULT_WEB_SEARCH, + ): bool, + vol.Optional( + CONF_HISTORY_ANALYSIS, + default=options.get(CONF_HISTORY_ANALYSIS, DEFAULT_HISTORY_ANALYSIS), + description={"suggested_value": options.get(CONF_HISTORY_ANALYSIS, DEFAULT_HISTORY_ANALYSIS)}, + ): bool, + } + + if not options.get(CONF_RECOMMENDED, False): + schema.update({ + vol.Optional( + CONF_MAX_TOKENS, + description={"suggested_value": options.get(CONF_MAX_TOKENS)}, + default=RECOMMENDED_MAX_TOKENS, + ): int, + vol.Optional( + CONF_TOP_P, + description={"suggested_value": options.get(CONF_TOP_P)}, + default=RECOMMENDED_TOP_P, + ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), + vol.Optional( + CONF_TEMPERATURE, + description={"suggested_value": options.get(CONF_TEMPERATURE)}, + default=RECOMMENDED_TEMPERATURE, + ): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)), + }) + + return schema + +class UnknownError(exceptions.HomeAssistantError): + pass + +class UnauthorizedError(exceptions.HomeAssistantError): + pass + +class InvalidAPIKey(exceptions.HomeAssistantError): + pass + +class ModelNotFound(exceptions.HomeAssistantError): + pass \ No newline at end of file diff --git a/custom_components/zhipuai/const.py b/custom_components/zhipuai/const.py new file mode 100644 index 0000000..af0e586 --- /dev/null +++ b/custom_components/zhipuai/const.py @@ -0,0 +1,57 @@ +import logging + +DOMAIN = "zhipuai" +LOGGER = logging.getLogger(__name__) +NAME = "自定义名称" +DEFAULT_NAME = "智谱清言" +CONF_RECOMMENDED = "recommended" +CONF_PROMPT = "prompt" +CONF_CHAT_MODEL = "chat_model" +RECOMMENDED_CHAT_MODEL = "GLM-4-Flash" +CONF_MAX_TOKENS = "max_tokens" +RECOMMENDED_MAX_TOKENS = 350 +CONF_TOP_P = "top_p" +RECOMMENDED_TOP_P = 0.7 +CONF_TEMPERATURE = "temperature" +RECOMMENDED_TEMPERATURE = 0.4 +CONF_MAX_HISTORY_MESSAGES = "max_history_messages" +RECOMMENDED_MAX_HISTORY_MESSAGES = 5 + +CONF_MAX_TOOL_ITERATIONS = "max_tool_iterations" +DEFAULT_MAX_TOOL_ITERATIONS = 20 +CONF_COOLDOWN_PERIOD = "cooldown_period" +DEFAULT_COOLDOWN_PERIOD = 3 + +CONF_WEB_SEARCH = "web_search" +DEFAULT_WEB_SEARCH = False + +ZHIPUAI_URL = "https://open.bigmodel.cn/api/paas/v4/chat/completions" + +ZHIPUAI_WEB_SEARCH_URL = "https://open.bigmodel.cn/api/paas/v4/tools" +CONF_WEB_SEARCH_STREAM = "web_search_stream" +DEFAULT_WEB_SEARCH_STREAM = False + +ZHIPUAI_IMAGE_GEN_URL = "https://open.bigmodel.cn/api/paas/v4/images/generations" +CONF_IMAGE_GEN = "image_gen" +DEFAULT_IMAGE_GEN = False + +CONF_IMAGE_SIZE = "image_size" +DEFAULT_IMAGE_SIZE = "1024x1024" + +IMAGE_SIZES = [ + "1024x1024", + "768x1344", + "864x1152", + "1344x768", + "1152x864", + "1440x720", + "720x1440" +] + + +CONF_HISTORY_ANALYSIS = "history_analysis" +CONF_HISTORY_ENTITIES = "history_entities" +CONF_HISTORY_DAYS = "history_days" +DEFAULT_HISTORY_ANALYSIS = False +DEFAULT_HISTORY_DAYS = 1 +MAX_HISTORY_DAYS = 15 \ No newline at end of file diff --git a/custom_components/zhipuai/conversation.py b/custom_components/zhipuai/conversation.py new file mode 100644 index 0000000..c16fc76 --- /dev/null +++ b/custom_components/zhipuai/conversation.py @@ -0,0 +1,546 @@ +import json +import asyncio +import time +import re +from datetime import datetime, timedelta +from typing import Any, Literal, TypedDict, Dict, List, Optional +from voluptuous_openapi import convert +from homeassistant.components import assist_pipeline, conversation +from homeassistant.components.conversation import trace +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, MATCH_ALL, ATTR_ENTITY_ID +from homeassistant.core import HomeAssistant, Context +from homeassistant.helpers import entity_registry as er +from homeassistant.helpers import device_registry as dr, intent, llm, template, entity_registry +from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.exceptions import HomeAssistantError +from homeassistant.components.recorder import get_instance +from homeassistant.components.recorder.history import get_significant_states +from homeassistant.util import ulid +from .ai_request import send_ai_request +from .intents import IntentHandler, extract_intent_info +from .const import ( + CONF_CHAT_MODEL, + CONF_MAX_TOKENS, + CONF_PROMPT, + CONF_TEMPERATURE, + CONF_TOP_P, + CONF_MAX_HISTORY_MESSAGES, + DOMAIN, + RECOMMENDED_CHAT_MODEL, + RECOMMENDED_MAX_TOKENS, + RECOMMENDED_TEMPERATURE, + RECOMMENDED_MAX_HISTORY_MESSAGES, + RECOMMENDED_TOP_P, + CONF_MAX_TOOL_ITERATIONS, + CONF_COOLDOWN_PERIOD, + DEFAULT_MAX_TOOL_ITERATIONS, + DEFAULT_COOLDOWN_PERIOD, + LOGGER, + CONF_HISTORY_ANALYSIS, + CONF_HISTORY_ENTITIES, + CONF_HISTORY_DAYS, + DEFAULT_HISTORY_DAYS, +) + + +class ChatCompletionMessageParam(TypedDict, total=False): + role: str + content: str | None + name: str | None + tool_calls: list["ChatCompletionMessageToolCallParam"] | None + +class Function(TypedDict, total=False): + name: str + arguments: str + +class ChatCompletionMessageToolCallParam(TypedDict): + id: str + type: str + function: Function + +class ChatCompletionToolParam(TypedDict): + type: str + function: dict[str, Any] + +def _format_tool(tool: llm.Tool, custom_serializer: Any | None) -> ChatCompletionToolParam: + tool_spec = { + "name": tool.name, + "parameters": convert(tool.parameters, custom_serializer=custom_serializer), + } + if tool.description: + tool_spec["description"] = tool.description + return ChatCompletionToolParam(type="function", function=tool_spec) + +def is_service_call(user_input: str) -> bool: + patterns = { + "control": ["让", "请", "帮我", "麻烦", "把", "将"], + "action": { + "turn_on": ["打开", "开启", "启动", "激活", "运行", "执行"], + "turn_off": ["关闭", "关掉", "停止"], + "toggle": ["切换"], + "press": ["按", "按下", "点击"], + "select": ["选择", "下一个", "上一个", "第一个", "最后一个"], + "trigger": ["触发", "调用"], + "media": ["暂停", "继续播放", "播放", "停止", "下一首", "下一曲", "下一个", "切歌", "换歌","上一首", "上一曲", "上一个", "返回上一首", "音量"] + } + } + return bool(user_input and (any(k in user_input for k in patterns["control"]) or any(k in v for v in patterns["action"].values() for k in v))) + +def extract_service_info(user_input: str, hass: HomeAssistant) -> Optional[Dict[str, Any]]: + def find_entity(domain: str, text: str) -> Optional[str]: + text = text.lower() + for e_id in hass.states.async_entity_ids(domain): + state = hass.states.get(e_id) + friendly_name = state.attributes.get("friendly_name", "").lower() + entity_name = e_id.split(".")[1].lower() + if (text in entity_name or text in friendly_name or + entity_name in text or friendly_name in text): + return e_id + return None + + def clean_text(text: str, patterns: List[str]) -> str: + control_words = ["让", "请", "帮我", "麻烦", "把", "将"] + for word in patterns + control_words: + text = text.replace(word, "") + return text.strip() + + if not is_service_call(user_input): + return None + + media_patterns = { + "暂停": "media_pause", + "继续播放": "media_play", + "播放": "media_play", + "停止": "media_stop", + "下一首": "media_next_track", + "下一曲": "media_next_track", + "下一个": "media_next_track", + "切歌": "media_next_track", + "换歌": "media_next_track", + "上一首": "media_previous_track", + "上一曲": "media_previous_track", + "上一个": "media_previous_track", + "返回上一首": "media_previous_track", + "音量": "volume_set" + } + for pattern, service in media_patterns.items(): + if pattern in user_input.lower(): + if entity_id := find_entity("media_player", user_input): + if service == "volume_set": + volume_match = re.search(r'(\d+)', user_input) + if volume_match: + volume = int(volume_match.group(1)) / 100 + return {"domain": "media_player", "service": service, "data": {"entity_id": entity_id, "volume_level": volume}} + return {"domain": "media_player", "service": service, "data": {"entity_id": entity_id}} + + button_patterns = ["按", "按下", "点击"] + if any(p in user_input for p in button_patterns): + if entity_id := (re.search(r'(button\.\w+)', user_input).group(1) if re.search(r'(button\.\w+)', user_input) + else find_entity("button", clean_text(user_input, button_patterns))): + return {"domain": "button", "service": "press", "data": {"entity_id": entity_id}} + + select_patterns = { + "下一个": ("select_next", True), + "上一个": ("select_previous", True), + "第一个": ("select_first", False), + "最后一个": ("select_last", False), + "选择": ("select_option", False) + } + if any(p in user_input for p in select_patterns.keys()): + if entity_id := find_entity("select", user_input): + pattern = next((k for k in select_patterns.keys() if k in user_input), "选择") + service, cycle = select_patterns[pattern] + return {"domain": "select", "service": service, "data": {"entity_id": entity_id, "cycle": cycle}} + + automation_patterns = ["触发", "调用", "执行", "运行", "启动"] + if any(p in user_input for p in automation_patterns): + name = clean_text(user_input, automation_patterns + ["脚本", "自动化", "场景"]) + if "脚本" in user_input.lower(): + if entity_id := find_entity("script", name): + return {"domain": "script", "service": "turn_on", "data": {"entity_id": entity_id}} + for domain, service in [("automation", "trigger"), ("script", "turn_on"), ("scene", "turn_on")]: + if entity_id := find_entity(domain, name): + return {"domain": domain, "service": service, "data": {"entity_id": entity_id}} + + return None + +class ZhipuAIConversationEntity(conversation.ConversationEntity, conversation.AbstractConversationAgent): + _attr_has_entity_name = True + _attr_name = None + _attr_response = "" + + def __init__(self, entry: ConfigEntry, hass: HomeAssistant) -> None: + self.entry = entry + self.hass = hass + self.history: dict[str, list[ChatCompletionMessageParam]] = {} + self._attr_unique_id = entry.entry_id + self._attr_device_info = dr.DeviceInfo( + identifiers={(DOMAIN, entry.entry_id)}, + name=entry.title, + manufacturer="智谱清言", + model="ChatGLM Pro", + entry_type=dr.DeviceEntryType.SERVICE, + ) + if self.entry.options.get(CONF_LLM_HASS_API): + self._attr_supported_features = conversation.ConversationEntityFeature.CONTROL + self.last_request_time = 0 + self.max_tool_iterations = min(entry.options.get(CONF_MAX_TOOL_ITERATIONS, DEFAULT_MAX_TOOL_ITERATIONS), 5) + self.cooldown_period = entry.options.get(CONF_COOLDOWN_PERIOD, DEFAULT_COOLDOWN_PERIOD) + self.llm_api = None + self.intent_handler = IntentHandler(hass) + self.entity_registry = er.async_get(hass) + self.device_registry = dr.async_get(hass) + self.service_call_attempts = 0 + self._attr_native_value = "就绪" + self._attr_extra_state_attributes = {"response": ""} + + def _filter_response_content(self, content: str) -> str: + content = re.sub(r'```[\s\S]*?```', '', content) + content = re.sub(r'{[\s\S]*?}', '', content) + content = re.sub(r'(?m)^(import|from|def|class)\s+.*$', '', content) + if not content.strip(): + return "抱歉,暂不支持该操作。如果问题持续,可能需要调整指令。" + return content.strip() + + @property + def supported_languages(self) -> list[str] | Literal["*"]: + return MATCH_ALL + + async def async_added_to_hass(self) -> None: + await super().async_added_to_hass() + assist_pipeline.async_migrate_engine(self.hass, "conversation", self.entry.entry_id, self.entity_id) + conversation.async_set_agent(self.hass, self.entry, self) + self.entry.async_on_unload(self.entry.add_update_listener(self._async_entry_update_listener)) + + async def async_will_remove_from_hass(self) -> None: + conversation.async_unset_agent(self.hass, self.entry) + await super().async_will_remove_from_hass() + + async def async_process(self, user_input: conversation.ConversationInput) -> conversation.ConversationResult: + if user_input.context and user_input.context.id and user_input.context.id.startswith(f"{DOMAIN}_service_call"): + return None + + current_time = time.time() + if current_time - self.last_request_time < self.cooldown_period: + await asyncio.sleep(self.cooldown_period - (current_time - self.last_request_time)) + self.last_request_time = time.time() + + intent_response = intent.IntentResponse(language=user_input.language) + + if is_service_call(user_input.text): + service_info = extract_service_info(user_input.text, self.hass) + if service_info: + result = await self.intent_handler.call_service( + service_info["domain"], + service_info["service"], + service_info["data"] + ) + if result["success"]: + intent_response.async_set_speech(result["message"]) + else: + intent_response.async_set_error( + intent.IntentResponseErrorCode.NO_VALID_TARGETS, + result["message"] + ) + return conversation.ConversationResult( + response=intent_response, + conversation_id=user_input.conversation_id + ) + + intent_info = extract_intent_info(user_input.text, self.hass) + if intent_info: + result = await self.intent_handler.handle_intent(intent_info) + if result["success"]: + intent_response.async_set_speech(result["message"]) + else: + intent_response.async_set_error( + intent.IntentResponseErrorCode.NO_VALID_TARGETS, + result["message"] + ) + return conversation.ConversationResult( + response=intent_response, + conversation_id=user_input.conversation_id + ) + + options = self.entry.options + tools: list[ChatCompletionToolParam] | None = None + user_name: str | None = None + llm_context = llm.LLMContext( + platform=DOMAIN, + context=user_input.context, + user_prompt=user_input.text, + language=user_input.language, + assistant=conversation.DOMAIN, + device_id=user_input.device_id, + ) + + try: + if options.get(CONF_LLM_HASS_API) and options[CONF_LLM_HASS_API] != "none": + self.llm_api = await llm.async_get_api(self.hass, options[CONF_LLM_HASS_API], llm_context) + tools = [_format_tool(tool, self.llm_api.custom_serializer) for tool in self.llm_api.tools][:8] + except HomeAssistantError as err: + intent_response.async_set_error(intent.IntentResponseErrorCode.UNKNOWN, f"获取 LLM API 时出错,将继续使用基本功能:{err}") + + if user_input.conversation_id is None: + conversation_id = ulid.ulid_now() + messages = [] + elif user_input.conversation_id in self.history: + conversation_id = user_input.conversation_id + messages = self.history[conversation_id] + else: + conversation_id = user_input.conversation_id + messages = [] + + max_history_messages = options.get(CONF_MAX_HISTORY_MESSAGES, RECOMMENDED_MAX_HISTORY_MESSAGES) + use_history = len(messages) < max_history_messages + + if user_input.context and user_input.context.user_id and (user := await self.hass.auth.async_get_user(user_input.context.user_id)): + user_name = user.name + try: + er = entity_registry.async_get(self.hass) + exposed_entities = [ + er.async_get(entity_id) for entity_id in self.hass.states.async_entity_ids() + if er.async_get(entity_id) and not er.async_get(entity_id).hidden + ] + + prompt_parts = [ + template.Template( + llm.BASE_PROMPT + options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT), + self.hass, + ).async_render( + { + "ha_name": self.hass.config.location_name, + "user_name": user_name, + "llm_context": llm_context, + "exposed_entities": exposed_entities, + }, + parse_result=False, + ) + ] + + if self.entry.options.get(CONF_HISTORY_ANALYSIS): + entities = self.entry.options.get(CONF_HISTORY_ENTITIES, []) + days = self.entry.options.get(CONF_HISTORY_DAYS, DEFAULT_HISTORY_DAYS) + + if entities: + end_time = datetime.now() + start_time = end_time - timedelta(days=days) + + history_text = [] + history_text.append(f"\n以下是询问者所关注的实体的历史数据分析({days}天内):") + + instance = get_instance(self.hass) + history_data = await instance.async_add_executor_job( + get_significant_states, + self.hass, + start_time, + end_time, + entities, + None, + True, + True + ) + + for entity_id in entities: + state = self.hass.states.get(entity_id) + if state is None: + continue + + if entity_id in history_data: + states = history_data[entity_id] + history_text.append(f"\n{entity_id}:") + for state in states: + if state is None: + continue + history_text.append( + f"- {state.state} ({state.last_updated.strftime('%Y-%m-%d %H:%M:%S')})" + ) + else: + history_text.append(f"\n{entity_id} (当前状态):") + history_text.append( + f"- {state.state} ({state.last_updated.strftime('%Y-%m-%d %H:%M:%S')})" + ) + + prompt_parts.append("\n".join(history_text)) + + except template.TemplateError as err: + content_message = f"抱歉,我的模板有问题: {err}" + filtered_content = self._filter_response_content(content_message) + intent_response.async_set_error(intent.IntentResponseErrorCode.UNKNOWN, filtered_content) + return conversation.ConversationResult(response=intent_response, conversation_id=conversation_id) + + if self.llm_api: + prompt_parts.append(self.llm_api.api_prompt) + + prompt = "\n".join(prompt_parts) + LOGGER.info("Prompt Parts: %s", prompt_parts) + LOGGER.info("生成的 Prompt: %s", prompt) + + messages = [ + ChatCompletionMessageParam(role="system", content=prompt), + *(messages if use_history else []), + ChatCompletionMessageParam(role="user", content=user_input.text), + ] + if len(messages) > max_history_messages + 1: + messages = [messages[0]] + messages[-(max_history_messages):] + + api_key = self.entry.data[CONF_API_KEY] + try: + for iteration in range(self.max_tool_iterations): + payload = { + "model": options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), + "messages": messages[-10:], + "max_tokens": min(options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), 1000), + "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P), + "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), + "request_id": conversation_id, + } + if tools: + payload["tools"] = tools + + + result = await send_ai_request(api_key, payload) + response = result["choices"][0]["message"] + + messages.append(response) + self._attr_response = response.get("content", "") + + tool_calls = response.get("tool_calls") + + if not tool_calls: + break + + tool_call_failed = False + for tool_call in tool_calls: + try: + tool_input = llm.ToolInput( + tool_name=tool_call["function"]["name"], + tool_args=json.loads(tool_call["function"]["arguments"]), + ) + + tool_response = await self._handle_tool_call(tool_input, user_input.text) + + if isinstance(tool_response, dict) and "error" in tool_response: + raise Exception(tool_response["error"]) + + formatted_response = json.dumps(tool_response) + messages.append( + ChatCompletionMessageParam( + role="tool", + tool_call_id=tool_call["id"], + content=formatted_response, + ) + ) + except Exception as e: + content_message = f"操作执行失败: {str(e)}" + messages.append( + ChatCompletionMessageParam( + role="tool", + tool_call_id=tool_call["id"], + content=content_message, + ) + ) + tool_call_failed = True + + if tool_call_failed and self.service_call_attempts >= 3: + return await self._fallback_to_hass_llm(user_input, conversation_id) + + + final_content = response.get("content", "").strip() + + if is_service_call(user_input.text): + service_info = extract_service_info(final_content, self.hass) + if service_info: + try: + await self.hass.services.async_call( + service_info["domain"], + service_info["service"], + service_info["data"], + blocking=True + ) + message = f"成功执行{service_info['domain']}:{service_info['data']['entity_id']}" + intent_response.response_text = message + return conversation.ConversationResult( + response=intent_response, + conversation_id=conversation_id + ) + except Exception as e: + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"执行服务失败:{str(e)}" + ) + return conversation.ConversationResult( + response=intent_response, + conversation_id=conversation_id + ) + + filtered_content = self._filter_response_content(final_content) + + self.history[conversation_id] = messages + intent_response.async_set_speech(filtered_content) + self._attr_extra_state_attributes["response"] = filtered_content + return conversation.ConversationResult(response=intent_response, conversation_id=conversation_id) + + except Exception as err: + return await self._fallback_to_hass_llm(user_input, conversation_id) + + async def _handle_tool_call(self, tool_input: llm.ToolInput, user_input: str) -> Dict[str, Any]: + try: + if self.llm_api and hasattr(self.llm_api, "async_call_tool"): + try: + result = await self.llm_api.async_call_tool(tool_input) + if isinstance(result, dict) and "error" not in result: + return result + except Exception as e: + LOGGER.warning("LLM API调用失败: %s", str(e)) + + if is_service_call(user_input): + service_info = extract_service_info(user_input, self.hass) + if service_info: + result = await self.intent_handler.call_service( + service_info["domain"], + service_info["service"], + service_info["data"] + ) + return result + else: + return {"error": "无法解析服务调用信息"} + + return {"error": "无法处理该工具调用"} + + except Exception as e: + return {"error": f"处理工具调用时发生错误: {str(e)}"} + + async def _fallback_to_hass_llm(self, user_input: conversation.ConversationInput, conversation_id: str) -> conversation.ConversationResult: + try: + agent = await conversation.async_get_agent(self.hass) + result = await agent.async_process(user_input) + return result + except Exception as err: + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"很抱歉,我现在无法正确处理您的请求。请稍后再试。错误: {err}" + ) + return conversation.ConversationResult(response=intent_response, conversation_id=conversation_id) + + @staticmethod + async def _async_entry_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None: + entity = hass.data[DOMAIN].get(entry.entry_id) + if entity: + entity.entry = entry + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + entity = ZhipuAIConversationEntity(config_entry, hass) + async_add_entities([entity]) + hass.data.setdefault(DOMAIN, {})[config_entry.entry_id] = entity + +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + if unload_ok := await hass.config_entries.async_unload_platforms(entry, ["conversation"]): + hass.data[DOMAIN].pop(entry.entry_id, None) + return unload_ok \ No newline at end of file diff --git a/custom_components/zhipuai/entity_analysis.py b/custom_components/zhipuai/entity_analysis.py new file mode 100644 index 0000000..8b25539 --- /dev/null +++ b/custom_components/zhipuai/entity_analysis.py @@ -0,0 +1,118 @@ + +from __future__ import annotations + +import json +import voluptuous as vol +from datetime import datetime, timedelta +from typing import List + +from homeassistant.core import HomeAssistant, ServiceCall, State +from homeassistant.exceptions import ServiceValidationError +from homeassistant.helpers import entity_registry as er, config_validation as cv +from homeassistant.components.recorder import get_instance +from homeassistant.components.recorder.history import get_significant_states + +from .const import DOMAIN, LOGGER + +ENTITY_ANALYSIS_SCHEMA = vol.Schema({ + vol.Required("entity_id"): vol.Any(cv.entity_id, [cv.entity_id]), + vol.Optional("days", default=3): vol.All( + vol.Coerce(int), + vol.Range(min=1, max=15) + ) +}) + +async def async_setup_entity_analysis(hass: HomeAssistant) -> None: + + + async def handle_entity_analysis(call: ServiceCall) -> dict: + + try: + entity_ids = call.data["entity_id"] + if isinstance(entity_ids, str): + entity_ids = [entity_ids] + + days = call.data.get("days", 3) + + entity_registry = er.async_get(hass) + valid_entity_ids = [] + current_states = {} + + for entity_id in entity_ids: + state = hass.states.get(entity_id) + if state is None: + LOGGER.warning("实体 %s 不存在", entity_id) + continue + + if not entity_registry.async_get(entity_id): + LOGGER.info("实体 %s 未注册但存在,将只获取当前状态", entity_id) + current_states[entity_id] = state + else: + valid_entity_ids.append(entity_id) + + if not valid_entity_ids and not current_states: + error_msg = "没有找到任何有效的实体" + LOGGER.error(error_msg) + return {"success": False, "message": error_msg} + + end_time = datetime.now() + start_time = end_time - timedelta(days=days) + history_text = [] + + if valid_entity_ids: + instance = get_instance(hass) + history_data = await instance.async_add_executor_job( + get_significant_states, + hass, + start_time, + end_time, + valid_entity_ids, + None, + True, + True + ) + + if history_data: + for entity_id in valid_entity_ids: + if entity_id not in history_data: + continue + + for state in history_data[entity_id]: + if state is None: + continue + history_text.append( + f"{entity_id}, {state.state}, {state.last_updated.strftime('%Y-%m-%d %H:%M:%S')}" + ) + + for entity_id, state in current_states.items(): + history_text.append( + f"{entity_id}, {state.state}, {state.last_updated.strftime('%Y-%m-%d %H:%M:%S')}" + ) + + message = ( + f"分析时间范围: {start_time.strftime('%Y-%m-%d %H:%M:%S')} 至 " + f"{end_time.strftime('%Y-%m-%d %H:%M:%S')}\n" + f"总计 {len(history_text)} 条记录\n\n" + "实体ID, 状态值, 更新时间\n" + f"{chr(10).join(history_text)}" + ) + + LOGGER.info("生成的历史记录文本:\n%s", message) + + return { + "success": True, + "message": message + } + + except Exception as e: + error_msg = f"获取历史记录失败: {str(e)}" + LOGGER.error(error_msg) + return {"success": False, "message": error_msg} + + hass.services.async_register( + DOMAIN, + "entity_analysis", + handle_entity_analysis, + schema=ENTITY_ANALYSIS_SCHEMA, + supports_response=True + ) \ No newline at end of file diff --git a/custom_components/zhipuai/image_gen.py b/custom_components/zhipuai/image_gen.py new file mode 100644 index 0000000..a19cec3 --- /dev/null +++ b/custom_components/zhipuai/image_gen.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import json +import os +import uuid +import aiohttp +import voluptuous as vol +from homeassistant.core import HomeAssistant, ServiceCall +from homeassistant.exceptions import ServiceValidationError + +from .const import ( + DOMAIN, + LOGGER, + ZHIPUAI_IMAGE_GEN_URL, + IMAGE_SIZES, + DEFAULT_IMAGE_SIZE, +) + +IMAGE_GEN_SCHEMA = vol.Schema({ + vol.Required("prompt"): str, + vol.Optional("model", default="cogview-3-flash"): vol.In(["cogview-3-plus", "cogview-3", "cogview-3-flash"]), + vol.Optional("size", default=DEFAULT_IMAGE_SIZE): vol.In(IMAGE_SIZES), +}) + +async def async_setup_image_gen(hass: HomeAssistant) -> None: + async def handle_image_gen(call: ServiceCall) -> None: + try: + config_entries = hass.config_entries.async_entries(DOMAIN) + if not config_entries: + raise ValueError("未找到 ZhipuAI 配置") + api_key = config_entries[0].data.get("api_key") + if not api_key: + raise ValueError("在配置中找不到 API 密钥") + + prompt = call.data["prompt"] + model = call.data.get("model", "cogview-3-flash") + size = call.data.get("size", DEFAULT_IMAGE_SIZE) + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + payload = { + "model": model, + "prompt": prompt, + "size": size + } + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + ZHIPUAI_IMAGE_GEN_URL, + headers=headers, + json=payload, + timeout=300 + ) as response: + response.raise_for_status() + result = await response.json() + + if not result.get("data") or not result["data"][0].get("url"): + raise ValueError("API 未返回有效的图片 URL") + + image_url = result["data"][0]["url"] + + # 下载图片并保存到本地 + www_dir = hass.config.path("www") + if not os.path.exists(www_dir): + os.makedirs(www_dir) + + filename = f"zhipuai_image_{uuid.uuid4().hex[:8]}.jpg" + filepath = os.path.join(www_dir, filename) + + async with aiohttp.ClientSession() as session: + async with session.get(image_url) as img_response: + img_response.raise_for_status() + with open(filepath, "wb") as f: + f.write(await img_response.read()) + + local_url = f"/local/{filename}" + + hass.bus.async_fire(f"{DOMAIN}_response", { + "type": "image_gen", + "content": local_url, + "success": True + }) + + return { + "success": True, + "message": local_url, + "original_url": image_url + } + + except aiohttp.ClientError as e: + LOGGER.error(f"API请求失败: {str(e)}") + raise ServiceValidationError(f"API请求失败: {str(e)}") + + except Exception as e: + error_msg = f"图像生成失败: {str(e)}" + LOGGER.error(error_msg) + return {"success": False, "message": error_msg} + + hass.services.async_register( + DOMAIN, + "image_gen", + handle_image_gen, + schema=IMAGE_GEN_SCHEMA, + supports_response=True + ) diff --git a/custom_components/zhipuai/intents.py b/custom_components/zhipuai/intents.py new file mode 100644 index 0000000..c196d24 --- /dev/null +++ b/custom_components/zhipuai/intents.py @@ -0,0 +1,287 @@ +from __future__ import annotations +import re +import os +import yaml +from datetime import timedelta +from typing import Any, Dict, List, Optional, Set +import voluptuous as vol +from homeassistant.components import camera +from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN +from homeassistant.components.lock import LockState +from homeassistant.components.timer import ( + ATTR_DURATION, ATTR_REMAINING, + CONF_DURATION, CONF_ICON, + DOMAIN as TIMER_DOMAIN, + SERVICE_START, SERVICE_PAUSE, SERVICE_CANCEL +) +from homeassistant.const import ATTR_DEVICE_CLASS, ATTR_ENTITY_ID +from homeassistant.core import Context, HomeAssistant, ServiceResponse, State +from homeassistant.helpers import area_registry, device_registry, entity_registry, intent +from .const import DOMAIN, LOGGER + +INTENT_CAMERA_ANALYZE = "ZhipuAICameraAnalyze" +INTENT_WEB_SEARCH = "ZhipuAIWebSearch" +INTENT_NEVERMIND = "nevermind" +SERVICE_PROCESS = "process" +ERROR_NO_CAMERA = "no_camera" +ERROR_NO_RESPONSE = "no_response" +ERROR_SERVICE_CALL = "service_call_error" +ERROR_NO_QUERY = "no_query" + +async def async_setup_intents(hass: HomeAssistant) -> None: + yaml_path = os.path.join(os.path.dirname(__file__), "intents.yaml") + if os.path.exists(yaml_path): + with open(yaml_path, 'r', encoding='utf-8') as f: + intents_config = yaml.safe_load(f) + LOGGER.info("从 %s 加载的 intent 配置", yaml_path) + + intent.async_register(hass, CameraAnalyzeIntent()) + intent.async_register(hass, WebSearchIntent()) + + +class CameraAnalyzeIntent(intent.IntentHandler): + intent_type = INTENT_CAMERA_ANALYZE + slot_schema = {vol.Required("camera_name"): str, vol.Required("question"): str} + + def __init__(self): + super().__init__() + yaml_path = os.path.join(os.path.dirname(__file__), "intents.yaml") + if os.path.exists(yaml_path): + with open(yaml_path, 'r', encoding='utf-8') as f: + self.config = yaml.safe_load(f).get(INTENT_CAMERA_ANALYZE, {}) + + async def async_handle(self, intent_obj: intent.Intent) -> intent.IntentResponse: + slots = self.async_validate_slots(intent_obj.slots) + camera_name = slots.get("camera_name", {}).get("value", "") + question = slots.get("question", {}).get("value", "") + + LOGGER.info("Camera analyze intent info - 原始插槽: %s", slots) + LOGGER.info("Camera analyze intent info - Camera: %s, Question: %s", camera_name, question) + + target_camera = next((e for e in intent_obj.hass.states.async_all(camera.DOMAIN) + if camera_name.lower() in (e.name.lower(), e.entity_id.lower())), None) + response = intent.IntentResponse(intent=intent_obj, language="zh-cn") + + if self.config and "speech" in self.config: + response.async_set_speech(self.config["speech"]["text"]) + + return (self._set_error_response(response, ERROR_NO_CAMERA, f"找不到名为 {camera_name} 的摄像头") + if not target_camera else await self._handle_camera_analysis(intent_obj.hass, response, target_camera, question)) + + async def _handle_camera_analysis(self, hass, response, target_camera, question) -> intent.IntentResponse: + try: + result = await hass.services.async_call(DOMAIN, "image_analyzer", + service_data={"model": "glm-4v-flash", "temperature": 0.8, "max_tokens": 1024, + "stream": False, "image_entity": target_camera.entity_id, "message": question}, + blocking=True, return_response=True) + return (self._set_speech_response(response, result.get("message", "")) + if result and isinstance(result, dict) and result.get("success", False) else + self._set_error_response(response, ERROR_SERVICE_CALL, result.get("message", "服务调用失败")) + if result and isinstance(result, dict) else + self._set_error_response(response, ERROR_NO_RESPONSE, "未能获取到有效的分析结果")) + except Exception as e: + return self._set_error_response(response, ERROR_SERVICE_CALL, f"服务调用出错:{str(e)}") + + def _set_error_response(self, response, code, message) -> intent.IntentResponse: + response.async_set_error(code=code, message=message) + return response + + def _set_speech_response(self, response, message) -> intent.IntentResponse: + response.async_set_speech(message) + return response + + +class WebSearchIntent(intent.IntentHandler): + intent_type = INTENT_WEB_SEARCH + slot_schema = {vol.Required("query"): str} + + def __init__(self): + super().__init__() + yaml_path = os.path.join(os.path.dirname(__file__), "intents.yaml") + if os.path.exists(yaml_path): + with open(yaml_path, 'r', encoding='utf-8') as f: + self.config = yaml.safe_load(f).get(INTENT_WEB_SEARCH, {}) + + async def async_handle(self, intent_obj: intent.Intent) -> intent.IntentResponse: + slots = self.async_validate_slots(intent_obj.slots) + query = slots.get("query", {}).get("value", "") + + LOGGER.info("Web 搜索 intent 信息 - 原始槽:%s", slots) + LOGGER.info("Web 搜索 意图 信息 - 提取的查询: %s", query) + + if not query: + response = intent.IntentResponse(intent=intent_obj, language="zh-cn") + return self._set_error_response(response, ERROR_NO_QUERY, "未提供搜索内容") + + return await self._handle_web_search(intent_obj.hass, intent_obj, query) + + async def _handle_web_search( + self, hass: HomeAssistant, intent_obj: intent.Intent, query: str + ) -> intent.IntentResponse: + response = intent.IntentResponse(intent=intent_obj, language="zh-cn") + + try: + LOGGER.info("Web search service call - Query: %s", query) + result = await hass.services.async_call( + DOMAIN, + "web_search", + service_data={ + "query": query, + "stream": False + }, + blocking=True, + return_response=True + ) + + if result and isinstance(result, dict): + LOGGER.info("Web search result: %s", result) + if result.get("success", False): + return self._set_speech_response(response, result.get("message", "")) + return self._set_error_response( + response, ERROR_SERVICE_CALL, result.get("message", "搜索服务调用失败") + ) + + return self._set_error_response( + response, ERROR_NO_RESPONSE, "未能获取到有效的搜索结果" + ) + + except Exception as e: + return self._set_error_response( + response, ERROR_SERVICE_CALL, f"搜索服务调用出错:{str(e)}" + ) + + def _set_error_response(self, response, code, message) -> intent.IntentResponse: + response.async_set_error(code=code, message=message) + return response + + def _set_speech_response(self, response, message) -> intent.IntentResponse: + if len(message.encode('utf-8')) > 12 * 1024: + message = message[:12 * 1024].rsplit(' ', 1)[0] + "..." + response.async_set_speech(message) + return response + + +def extract_intent_info(user_input: str, hass: HomeAssistant) -> Optional[Dict[str, Any]]: + entity_match = re.search(r'([\w_]+\.[\w_]+)', user_input) + entity_id = entity_match.group(1) if entity_match else None + domain = entity_id.split('.')[0] if entity_id else None + + intent_mappings = { + r'打开|开启|解锁|turn on|open|unlock': 'turn_on', + r'关闭|关掉|锁定|turn off|close|lock': 'turn_off', + r'切换|toggle': 'toggle', + r'停止|暂停|stop|pause': 'stop', + r'继续|resume': 'start', + r'设置|调整|set|adjust': 'set', + r'没事|算了|再见|闭嘴|退下|nevermind|bye': INTENT_NEVERMIND + } + + action = next((action for pattern, action in intent_mappings.items() + if re.search(pattern, user_input)), 'turn_on') + + action = 'lock' if domain == 'lock' and action == 'turn_on' else \ + 'unlock' if domain == 'lock' and action == 'turn_off' else action + + intent_data = {'domain': domain, 'action': action, 'data': {'entity_id': entity_id}} if entity_id else None + + + if action == 'set' and intent_data: + number_match = re.search(r'(\d+)', user_input) + value = int(number_match.group(1)) if number_match else None + if value is not None: + intent_data['data'].update({ + 'position': value if domain == 'cover' else + min(255, value * 255 // 100) if domain == 'light' else + value if domain == 'climate' else None + }) + + return intent_data + +class IntentHandler: + def __init__(self, hass: HomeAssistant): + self.hass = hass + self.area_reg = area_registry.async_get(hass) + self.device_reg = device_registry.async_get(hass) + self.entity_reg = entity_registry.async_get(hass) + + async def call_service(self, domain: str, service: str, data: Dict[str, Any]) -> Dict[str, Any]: + try: + entity_id = data.pop("entity_id", None) + entity = self.hass.states.get(entity_id) + friendly_name = entity.attributes.get("friendly_name") if entity else "设备" + service_data = {**data, "entity_id": entity_id} if entity_id else data + await self.hass.services.async_call(domain, service, service_data, blocking=True) + return {"success": True, "message": f"您好,我已执行 {friendly_name}", "data": service_data} + except Exception as e: + return {"success": False, "message": str(e), "data": data} + + async def handle_intent(self, intent_info: Dict[str, Any]) -> Dict[str, Any]: + domain = intent_info.get('domain') + action = intent_info.get('action', '') + data = intent_info.get('data', {}) + name = data['name'].get('value') if isinstance(data.get('name'), dict) else data.get('name') + area = data['area'].get('value') if isinstance(data.get('area'), dict) else data.get('area') + entity_id = data.get('entity_id') + + return await ( + self.handle_nevermind_intent() if domain == CONVERSATION_DOMAIN and action == INTENT_NEVERMIND else + self.call_service(domain, action, data) if entity_id and await self._validate_service_for_entity(domain, action, entity_id) else + self.handle_cover_intent(action, name, area, data) if domain == "cover" else + self.handle_lock_intent(action, name, area, data) if domain == "lock" else + self.handle_timer_intent(action, name, area, data) if domain == TIMER_DOMAIN else + self.call_service(domain, action, data) + ) + + async def handle_nevermind_intent(self) -> Dict[str, Any]: + try: + await self.hass.services.async_call(CONVERSATION_DOMAIN, SERVICE_PROCESS, {"text": "再见"}, blocking=True) + return {"success": True, "message": "再见!", "close_conversation": True} + except Exception as e: + return {"success": False, "message": str(e)} + + async def _validate_service_for_entity(self, domain: str, service: str, entity_id: str) -> bool: + state = self.hass.states.get(entity_id) + return bool( + state and entity_id in self.hass.states.async_entity_ids() and + (supported_features := state.attributes.get("supported_features", 0)) and + not ( + (domain == "fan" and service == "set_percentage" and not (supported_features & 1)) or + (domain == "cover" and ((service == "close" and not (supported_features & 2)) or + (service == "set_position" and not (supported_features & 4)))) or + (domain == "lock" and ((service == "unlock" and not (supported_features & 1)) or + (service == "lock" and not (supported_features & 2)))) + ) + ) + + async def handle_timer_intent(self, action: str, name: str, area: str, data: Dict[str, Any]) -> Dict[str, Any]: + timer_name = name if name else "default" + entity_id = f"timer.{timer_name}" + return await ( + self.call_service(TIMER_DOMAIN, SERVICE_START, {"duration": data.get('duration'), "entity_id": entity_id}) + if action == "start" and data.get('duration') else + self.call_service(TIMER_DOMAIN, SERVICE_PAUSE, {"entity_id": entity_id}) + if action == "pause" and name else + self.call_service(TIMER_DOMAIN, SERVICE_CANCEL, {"entity_id": entity_id}) + if action == "cancel" and name else + {"success": False, "message": f"不支持的定时器操作: {action}"} + ) + + async def handle_cover_intent(self, action: str, name: str, area: str, data: Dict[str, Any]) -> Dict[str, Any]: + entity_id = data.get('entity_id') + return ( + {"success": False, "message": "未指定窗帘实体"} if not entity_id else + await self.call_service("cover", "set_cover_position", + {"entity_id": entity_id, "position": data.get('position')}) + if action == "set" and data.get('position') is not None else + await self.call_service("cover", action, {"entity_id": entity_id}) + ) + + async def handle_lock_intent(self, action: str, name: str, area: str, data: Dict[str, Any]) -> Dict[str, Any]: + entity_id = data.get('entity_id') + return ( + {"success": False, "message": "未指定门锁实体"} if not entity_id else + await self.call_service("lock", action, {"entity_id": entity_id}) + ) + +def get_intent_handler(hass: HomeAssistant) -> IntentHandler: + return IntentHandler(hass) \ No newline at end of file diff --git a/custom_components/zhipuai/intents.yaml b/custom_components/zhipuai/intents.yaml new file mode 100644 index 0000000..f929cd0 --- /dev/null +++ b/custom_components/zhipuai/intents.yaml @@ -0,0 +1,178 @@ +ZHIPUAI_CoverGetStateIntent: + data: + - sentences: + - "[]{name}(是|是不是){cover_states:state}[吗|不]" + response: one_yesno + requires_context: + domain: cover + slots: + domain: cover + + - sentences: + - "[{area}][有|有没有]{cover_classes:device_class}[是|是不是]{cover_states:state}[吗|不]" + response: any + slots: + domain: cover + + - sentences: + - "[][]{cover_classes:device_class}[是|是不是]都[是]{cover_states:state}[吗|不]" + - "{cover_classes:device_class}[是|是不是]都[是]{cover_states:state}[吗|不]" + response: all + slots: + domain: cover + + - sentences: + - "[]{cover_classes:device_class}[是]{cover_states:state}" + - "[]{cover_classes:device_class}[是]{cover_states:state}" + response: which + slots: + domain: cover + + - sentences: + - "[{area}]{cover_classes:device_class}[是]{cover_states:state}" + - "[]{cover_classes:device_class}[是]{cover_states:state}" + response: how_many + slots: + domain: cover + +ZHIPUAI_CoverSetPositionIntent: + data: + - sentences: + - "(||) [position] " + requires_context: + domain: cover + slots: + domain: cover + + - sentences: + - "(||) {cover_classes:device_class}[ position] " + slots: + domain: cover + + - sentences: + - "[] (||) [position] " + requires_context: + domain: cover + slots: + domain: cover + + - sentences: + - "[]{cover_classes:device_class}(||)[ position] " + slots: + domain: cover + + +HassGetWeather: + data: + - sentences: + - "[现在|当前][的]天气[怎么样]" + - "查看[现在|当前][的]天气" + - "天气[预报|情况]" + slots: + domain: weather + + - sentences: + - "[查询|查看][]{name}[的]天气[怎么样]" + - "[现在|当前][]{name}[的]天气[如何|状况]" + - "[]{name}天气[预报]" + requires_context: + domain: weather + +HassGetState: + data: + - sentences: + - "[查询|查看][]{name}[的]状态" + - "[]{name}[现在|当前]是什么状态" + - "[]{name}[怎么样|如何]" + slots: + domain: all + + - sentences: + - "[]{name}[的锁](是不是|有没有){lock_states:state}" + - "[]{name}[的锁][是|有]{lock_states:state}[吗|不]" + response: one_yesno + requires_context: + domain: lock + slots: + domain: lock + +ZhipuAIWebSearch: + data: + - sentences: + - "搜索{query}" + - "帮我联网{query}" + - "帮我上网查{query}" + - "联网搜索{query}" + - "联网查找{query}" + - "联网查询{query}" + - "在网上搜索{query}" + - "在网上查找{query}" + - "在网上查询{query}" + - "上网搜索{query}" + - "上网查找{query}" + - "上网查询{query}" + - "网上搜索{query}" + - "网上查找{query}" + - "网上查询{query}" + - "互联网搜索{query}" + - "互联网查找{query}" + - "互联网查询{query}" + - "百度{query}" + - "谷歌{query}" + - "必应{query}" + - "搜一下{query}" + - "查一下{query}" + - "帮我找找{query}" + - "帮忙搜索{query}" + - "帮忙查找{query}" + - "帮忙查询{query}" + - "search for {query}" + - "search {query}" + - "find {query}" + - "look up {query}" + - "google {query}" + - "bing {query}" + slots: + query: + type: text + example: "中国队奥运会奖牌数" + +CameraAnalyzeIntent: + data: + - sentences: + - "分析{camera_name}看到了什么{question}" + - "分析{camera_name}拍到了什么{question}" + - "分析{camera_name}的画面{question}" + - "分析{camera_name}的内容{question}" + - "分析{camera_name}的图像{question}" + - "分析{camera_name}的视频{question}" + - "帮我看看{camera_name}{question}" + - "帮我分析{camera_name}{question}" + - "帮我识别{camera_name}{question}" + - "查看{camera_name}的画面{question}" + - "查看{camera_name}拍到什么{question}" + - "检查{camera_name}的画面{question}" + - "观察{camera_name}的画面{question}" + - "识别{camera_name}的画面{question}" + - "看看{camera_name}拍到了什么{question}" + - "看一下{camera_name}的画面{question}" + - "告诉我{camera_name}拍到了什么{question}" + - "描述{camera_name}的画面{question}" + - "解释{camera_name}的画面{question}" + - "分析一下{camera_name}的画面{question}" + - "analyze what {camera_name} sees {question}" + - "check what {camera_name} captured {question}" + - "describe what {camera_name} sees {question}" + slots: + camera_name: + type: text + example: "客厅摄像头" + question: + type: text + example: "有什么异常吗" + speech: + text: 正在分析摄像头内容 + action: + service: image_processing.scan + data: + entity_id: "{{ camera_name }}" \ No newline at end of file diff --git a/custom_components/zhipuai/manifest.json b/custom_components/zhipuai/manifest.json new file mode 100644 index 0000000..eecf9fe --- /dev/null +++ b/custom_components/zhipuai/manifest.json @@ -0,0 +1,12 @@ +{ + "domain": "zhipuai", + "name": "\u667a\u8c31\u6e05\u8a00", + "codeowners": ["@knoop7"], + "config_flow": true, + "documentation": "https://github.com/knoop7/zhipuai", + "integration_type": "hub", + "iot_class": "cloud_polling", + "issue_tracker": "https://github.com/knoop7/zhipuai/issues", + "requirements": [], + "version": "2024.12.31" +} diff --git a/custom_components/zhipuai/services.py b/custom_components/zhipuai/services.py new file mode 100644 index 0000000..792fecd --- /dev/null +++ b/custom_components/zhipuai/services.py @@ -0,0 +1,440 @@ +from __future__ import annotations + +import base64 +import io +import os +import os.path +import numpy as np +from PIL import Image +import voluptuous as vol +from homeassistant.core import HomeAssistant, ServiceCall, callback +from homeassistant.helpers import config_validation as cv +from homeassistant.components import camera +from homeassistant.components.camera import Image as CameraImage +from homeassistant.exceptions import ServiceValidationError +import requests +import json +import time +import asyncio + +from .const import ( + DOMAIN, + LOGGER, + ZHIPUAI_URL, + CONF_TEMPERATURE, + RECOMMENDED_TEMPERATURE, + CONF_MAX_TOKENS, + RECOMMENDED_MAX_TOKENS, +) + +class ImageProcessor: + def __init__(self, hass: HomeAssistant): + self.hass = hass + + async def resize_image(self, image_data: bytes, target_width: int = 1280, is_gif: bool = False) -> str: + try: + img_byte_arr = io.BytesIO(image_data) + img = await self.hass.async_add_executor_job(Image.open, img_byte_arr) + + if is_gif and img.format == 'GIF' and getattr(img, "is_animated", False): + frames = [] + try: + for frame in range(img.n_frames): + img.seek(frame) + new_frame = await self.hass.async_add_executor_job(lambda x: x.convert('RGB'), img.copy()) + + width, height = new_frame.size + aspect_ratio = width / height + target_height = int(target_width / aspect_ratio) + if width > target_width or height > target_height: + new_frame = await self.hass.async_add_executor_job(lambda x: x.resize((target_width, target_height)), new_frame) + frames.append(new_frame) + + output = io.BytesIO() + frames[0].save(output, save_all=True, append_images=frames[1:], format='GIF', duration=img.info.get('duration', 100), loop=0) + base64_image = base64.b64encode(output.getvalue()).decode('utf-8') + return base64_image + except Exception as e: + LOGGER.error(f"GIF处理错误: {str(e)}") + img.seek(0) + + if img.mode == 'RGBA' or img.format == 'GIF': + img = await self.hass.async_add_executor_job(lambda x: x.convert('RGB'), img) + + width, height = img.size + aspect_ratio = width / height + target_height = int(target_width / aspect_ratio) + + if width > target_width or height > target_height: + img = await self.hass.async_add_executor_job(lambda x: x.resize((target_width, target_height)), img) + + img_byte_arr = io.BytesIO() + await self.hass.async_add_executor_job(lambda i, b: i.save(b, format='JPEG'), img, img_byte_arr) + base64_image = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8') + + return base64_image + except Exception as e: + LOGGER.error(f"图像处理错误: {str(e)}") + raise ServiceValidationError(f"图像处理失败: {str(e)}") + + def _similarity_score(self, previous_frame, current_frame_gray): + K1 = 0.005 + K2 = 0.015 + L = 255 + + C1 = (K1 * L) ** 2 + C2 = (K2 * L) ** 2 + + previous_frame_np = np.array(previous_frame, dtype=np.float64) + current_frame_np = np.array(current_frame_gray, dtype=np.float64) + + mu1 = np.mean(previous_frame_np) + mu2 = np.mean(current_frame_np) + + sigma1_sq = np.var(previous_frame_np) + sigma2_sq = np.var(current_frame_np) + sigma12 = np.cov(previous_frame_np.flatten(), current_frame_np.flatten())[0, 1] + + ssim = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ((mu1**2 + mu2**2 + C1) * (sigma1_sq + sigma2_sq + C2)) + + return ssim + +IMAGE_ANALYZER_SCHEMA = vol.Schema({ + vol.Required("model", default="glm-4v-flash"): vol.In(["glm-4v-plus", "glm-4v", "glm-4v-flash"]), + vol.Required("message"): cv.string, + vol.Optional("image_file"): cv.string, + vol.Optional("image_entity"): cv.entity_id, + vol.Optional(CONF_TEMPERATURE, default=RECOMMENDED_TEMPERATURE): vol.All(vol.Coerce(float), vol.Range(min=0.1, max=1.0)), + vol.Optional(CONF_MAX_TOKENS, default=RECOMMENDED_MAX_TOKENS): vol.All(vol.Coerce(int), vol.Range(min=1, max=1024)), + vol.Optional("stream", default=False): cv.boolean, + vol.Optional("target_width", default=1280): vol.All(vol.Coerce(int), vol.Range(min=512, max=1920)), +}) + +VIDEO_ANALYZER_SCHEMA = vol.Schema( + { + vol.Required("model", default="glm-4v-plus"): cv.string, + vol.Required("message"): cv.string, + vol.Required("video_file"): cv.string, + vol.Optional(CONF_TEMPERATURE, default=RECOMMENDED_TEMPERATURE): vol.All( + vol.Coerce(float), vol.Range(min=0.1, max=1.0) + ), + vol.Optional(CONF_MAX_TOKENS, default=RECOMMENDED_MAX_TOKENS): vol.All( + vol.Coerce(int), vol.Range(min=1, max=1024) + ), + vol.Optional("stream", default=False): cv.boolean, + vol.Optional("target_width", default=1280): vol.All( + vol.Coerce(int), vol.Range(min=512, max=1920) + ), + } +) + +async def async_setup_services(hass: HomeAssistant) -> None: + image_processor = ImageProcessor(hass) + + async def handle_image_analyzer(call: ServiceCall) -> None: + try: + config_entries = hass.config_entries.async_entries(DOMAIN) + if not config_entries: + raise ValueError("未找到 ZhipuAI") + api_key = config_entries[0].data.get("api_key") + if not api_key: + raise ValueError("在配置中找不到 API 密钥") + + image_data = None + if image_file := call.data.get("image_file"): + try: + if not os.path.isabs(image_file): + image_file = os.path.join(hass.config.config_dir, image_file) + + if os.path.isdir(image_file): + raise ServiceValidationError(f"提供的路径是一个目录: {image_file}") + + if not os.path.exists(image_file): + raise ServiceValidationError(f"图片文件不存在: {image_file}") + + with open(image_file, "rb") as f: + image_data = f.read() + except IOError as e: + LOGGER.error(f"读取图片文件失败 {image_file}: {str(e)}") + raise ServiceValidationError(f"读取图片文件失败: {str(e)}") + + elif image_entity := call.data.get("image_entity"): + try: + if not image_entity.startswith("camera."): + raise ServiceValidationError(f"无效的摄像头实体ID: {image_entity}") + + if not hass.states.get(image_entity): + raise ServiceValidationError(f"摄像头实体不存在: {image_entity}") + + try: + image: CameraImage = await camera.async_get_image(hass, image_entity, timeout=10) + + if not image or not image.content: + raise ServiceValidationError(f"无法从摄像头获取图片: {image_entity}") + + image_data = image.content + base64_image = await image_processor.resize_image(image_data, target_width=call.data.get("target_width", 1280), is_gif=True) + + except (camera.CameraEntityImageError, TimeoutError) as e: + raise ServiceValidationError(f"获取摄像头图片失败: {str(e)}") + + except Exception as e: + LOGGER.error(f"从实体获取图片失败 {image_entity}: {str(e)}") + raise ServiceValidationError(f"从实体获取图片失败: {str(e)}") + + if not image_data: + raise ServiceValidationError("未提供图片数据") + + try: + base64_image = await image_processor.resize_image(image_data, target_width=call.data.get("target_width", 1280)) + except Exception as e: + LOGGER.error(f"图像处理失败: {str(e)}") + raise ServiceValidationError(f"图像处理失败: {str(e)}") + + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {api_key}' + } + payload = { + "model": call.data["model"], + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}" + } + }, + { + "type": "text", + "text": call.data["message"] + } + ] + } + ], + "stream": call.data.get("stream", False), + "temperature": call.data.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), + "max_tokens": call.data.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), + } + + try: + response = await hass.async_add_executor_job(lambda: requests.post(ZHIPUAI_URL, headers=headers, json=payload, stream=call.data.get("stream", False), timeout=30)) + response.raise_for_status() + except requests.exceptions.RequestException as e: + LOGGER.error(f"API请求失败: {str(e)}") + raise ServiceValidationError(f"API请求失败: {str(e)}") + + if call.data.get("stream", False): + event_id = f"zhipuai_response_{int(time.time())}" + + try: + hass.bus.async_fire(f"{DOMAIN}_stream_start", {"event_id": event_id, "type": "image_analysis"}) + + accumulated_text = "" + for line in response.iter_lines(): + if line: + try: + data = line.decode('utf-8').replace('data: ', '') + if data == '[DONE]': + break + + json_data = json.loads(data) + if 'choices' in json_data and len(json_data['choices']) > 0: + content = json_data['choices'][0].get('delta', {}).get('content', '') + if content: + accumulated_text += content + hass.bus.async_fire(f"{DOMAIN}_stream_token", {"event_id": event_id, "content": content, "full_content": accumulated_text}) + + except json.JSONDecodeError as e: + LOGGER.error(f"解析流式响应失败: {str(e)}") + continue + + hass.bus.async_fire(f"{DOMAIN}_stream_end", {"event_id": event_id, "full_content": accumulated_text}) + + hass.bus.async_fire(f"{DOMAIN}_response", {"type": "image_analysis", "content": accumulated_text, "success": True}) + return {"success": True, "event_id": event_id, "message": accumulated_text} + + except Exception as e: + error_msg = f"处理流式响应时出错: {str(e)}" + LOGGER.error(error_msg) + hass.bus.async_fire(f"{DOMAIN}_stream_error", {"event_id": event_id, "error": error_msg}) + return {"success": False, "message": error_msg} + else: + result = response.json() + if result.get("choices") and len(result["choices"]) > 0: + content = result["choices"][0].get("message", {}).get("content", "") + hass.bus.async_fire(f"{DOMAIN}_response", {"type": "image_analysis", "content": content, "success": True}) + return {"success": True, "message": content} + else: + error_msg = "No response from API" + return {"success": False, "message": error_msg} + except Exception as e: + error_msg = f"Image analysis failed: {str(e)}" + LOGGER.error(f"图像分析错误: {str(e)}") + return {"success": False, "message": error_msg} + + async def handle_video_analyzer(call: ServiceCall) -> None: + try: + config_entries = hass.config_entries.async_entries(DOMAIN) + if not config_entries: + raise ValueError("No ZhipuAI configuration found") + api_key = config_entries[0].data.get("api_key") + if not api_key: + raise ValueError("API key not found in configuration") + + video_file = call.data["video_file"] + if not os.path.isabs(video_file): + video_file = os.path.join(hass.config.config_dir, video_file) + + if not os.path.isfile(video_file): + LOGGER.error(f"视频文件未找到或是目录: {video_file}") + return {"success": False, "message": f"视频文件未找到或是目录: {video_file}"} + + try: + def read_video_file(): + with open(video_file, "rb") as f: + return f.read() + video_data = await hass.async_add_executor_job(read_video_file) + except IOError as e: + LOGGER.error(f"读取视频文件失败 {video_file}: {str(e)}") + return {"success": False, "message": f"读取视频文件失败: {str(e)}"} + + if call.data.get("model") != "glm-4v-plus": + LOGGER.warning("视频分析仅支持glm-4v-plus模型,强制使用glm-4v-plus") + + video_size = len(video_data) / (1024 * 1024) + if video_size > 20: + LOGGER.error(f"视频文件大小 ({video_size:.1f}MB) 超过20MB限制") + return {"success": False, "message": f"视频文件大小 ({video_size:.1f}MB) 超过20MB限制"} + + if not video_file.lower().endswith('.mp4'): + LOGGER.error("视频文件必须是MP4格式") + return {"success": False, "message": "视频文件必须是MP4格式"} + + base64_video = base64.b64encode(video_data).decode('utf-8') + + payload = { + "model": "glm-4v-plus", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "video_url", + "video_url": { + "url": base64_video + } + }, + { + "type": "text", + "text": call.data["message"] + } + ] + } + ], + "stream": call.data.get("stream", False) + } + + if CONF_TEMPERATURE in call.data: + payload["temperature"] = call.data[CONF_TEMPERATURE] + if CONF_MAX_TOKENS in call.data: + payload["max_tokens"] = call.data[CONF_MAX_TOKENS] + + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {api_key}' + } + + try: + for attempt in range(3): + try: + response = await hass.async_add_executor_job( + lambda: requests.post( + ZHIPUAI_URL, + headers=headers, + json=payload, + stream=call.data.get("stream", False), + timeout=30 + ) + ) + response.raise_for_status() + break + except requests.exceptions.RequestException as e: + if attempt == 2: + raise + if response.status_code == 429: + await asyncio.sleep(2) + else: + raise + except requests.exceptions.RequestException as e: + LOGGER.error(f"API请求失败: {str(e)}") + return {"success": False, "message": f"API请求失败: {str(e)}"} + + if call.data.get("stream", False): + event_id = f"zhipuai_response_{int(time.time())}" + + try: + hass.bus.async_fire(f"{DOMAIN}_stream_start", {"event_id": event_id, "type": "video_analysis"}) + + accumulated_text = "" + for line in response.iter_lines(): + if line: + try: + line_text = line.decode('utf-8').strip() + if line_text.startswith('data: '): + line_text = line_text[6:] + if line_text == '[DONE]': + break + + json_data = json.loads(line_text) + if 'choices' in json_data and json_data['choices']: + choice = json_data['choices'][0] + if 'delta' in choice and 'content' in choice['delta']: + content = choice['delta']['content'] + accumulated_text += content + hass.bus.async_fire( + f"{DOMAIN}_stream_token", + { + "event_id": event_id, + "token": content, + "complete_text": accumulated_text + } + ) + except json.JSONDecodeError as e: + LOGGER.error(f"解析流式响应失败: {str(e)}") + continue + + hass.bus.async_fire(f"{DOMAIN}_stream_end", { + "event_id": event_id, + "complete_text": accumulated_text + }) + return {"success": True, "message": accumulated_text} + except Exception as e: + LOGGER.error(f"处理流式响应时出错: {str(e)}") + return {"success": False, "message": f"处理流式响应时出错: {str(e)}"} + else: + result = response.json() + if result.get("choices") and len(result["choices"]) > 0: + content = result["choices"][0].get("message", {}).get("content", "") + hass.bus.async_fire(f"{DOMAIN}_response", {"type": "video_analysis", "content": content, "success": True}) + return {"success": True, "message": content} + else: + error_msg = "No response from API" + return {"success": False, "message": error_msg} + except Exception as e: + error_msg = f"Video analysis failed: {str(e)}" + LOGGER.error(f"视频分析错误: {str(e)}") + return {"success": False, "message": error_msg} + + hass.services.async_register(DOMAIN, "image_analyzer", handle_image_analyzer, schema=IMAGE_ANALYZER_SCHEMA, supports_response=True) + + hass.services.async_register(DOMAIN, "video_analyzer", handle_video_analyzer, schema=VIDEO_ANALYZER_SCHEMA, supports_response=True) + + @callback + def async_unload_services() -> None: + hass.services.async_remove(DOMAIN, "image_analyzer") + hass.services.async_remove(DOMAIN, "video_analyzer") + + return async_unload_services diff --git a/custom_components/zhipuai/services.yaml b/custom_components/zhipuai/services.yaml new file mode 100644 index 0000000..6c88c6e --- /dev/null +++ b/custom_components/zhipuai/services.yaml @@ -0,0 +1,221 @@ +image_analyzer: + name: 图像分析 + description: 使用智谱GLM-4V模型分析图像 + fields: + model: + name: 模型 + description: '选择要使用的图像分析模型' + required: true + example: "glm-4v-flash" + default: "glm-4v-flash" + selector: + select: + options: + - "glm-4v-plus" + - "glm-4v" + - "glm-4v-flash" + message: + name: 提示词 + required: true + description: '给模型的提示词' + example: "请描述这张图片的内容" + selector: + text: + multiline: true + image_file: + name: 图片文件 + required: false + description: '本地图片路径(支持jpg、png、jpeg格式,最大5MB,最大分辨率6000x6000像素)' + example: "/config/www/tmp/front_door.jpg" + selector: + text: + multiline: false + image_entity: + name: 图片实体 + required: false + description: '要分析的图片或摄像头实体' + example: 'camera.front_door' + selector: + entity: + domain: ["image", "camera"] + multiple: false + temperature: + name: 温度 + required: false + description: '控制输出的随机性(0.1-1.0)。值越低,输出越稳定' + example: 0.8 + default: 0.8 + selector: + number: + min: 0.1 + max: 1.0 + step: 0.1 + max_tokens: + name: 最大令牌数 + required: false + description: '限制生成文本的最大长度' + example: 1024 + default: 1024 + selector: + number: + min: 1 + max: 1024 + step: 1 + stream: + name: 流式响应 + required: false + description: '是否使用流式响应(实时返回生成结果)' + example: false + default: false + selector: + boolean: {} + +video_analyzer: + name: 视频分析 + description: 使用智谱GLM-4V-Plus模型分析视频 + fields: + model: + name: 模型 + description: '视频分析模型 (仅支持 GLM-4V-Plus)' + required: false + default: "glm-4v-plus" + selector: + select: + options: + - "glm-4v-plus" + message: + name: 提示词 + required: true + description: '给模型的提示词' + example: "请描述这段视频的内容" + selector: + text: + multiline: true + video_file: + name: 视频文件 + required: true + description: '本地视频文件路径(支持mp4格式,建议时长不超过30秒)' + example: "/config/www/tmp/video.mp4" + selector: + text: + multiline: false + temperature: + name: 温度 + required: false + description: '控制输出的随机性(0.1-1.0)。值越低,输出越稳定' + example: 0.8 + default: 0.8 + selector: + number: + min: 0.1 + max: 1.0 + step: 0.1 + max_tokens: + name: 最大令牌数 + required: false + description: '限制生成文本的最大长度' + example: 1024 + default: 1024 + selector: + number: + min: 1 + max: 1024 + step: 1 + stream: + name: 流式响应 + required: false + description: '是否使用流式响应(实时返回生成结果)' + example: false + default: false + selector: + boolean: {} + +image_gen: + name: 图像生成 + description: 使用 CogView-3 模型生成图像 + fields: + prompt: + name: 图像描述 + description: 所需图像的文本描述 + required: true + example: "一只可爱的小猫咪" + selector: + text: + multiline: true + model: + name: 模型 + description: 选择要使用的模型版本 + required: false + default: cogview-3-flash + selector: + select: + options: + - label: CogView-3 Plus + value: cogview-3-plus + - label: CogView-3 + value: cogview-3 + - label: CogView-3 Flash (免费) + value: cogview-3-flash + size: + name: 图片尺寸 + description: 生成图片的尺寸大小 + required: false + default: 1024x1024 + selector: + select: + options: + - label: 1024x1024 + value: 1024x1024 + - label: 768x1344 + value: 768x1344 + - label: 864x1152 + value: 864x1152 + - label: 1344x768 + value: 1344x768 + - label: 1152x864 + value: 1152x864 + - label: 1440x720 + value: 1440x720 + - label: 720x1440 + value: 720x1440 + +web_search: + name: 联网搜索 + description: 使用智谱AI的web-search-pro工具进行联网搜索 + fields: + query: + name: 搜索内容 + description: '要搜索的内容' + required: true + example: "中国队奥运会拿了多少奖牌" + selector: + text: + multiline: true + stream: + name: 流式响应 + description: '是否使用流式响应(实时返回生成结果)' + required: false + default: false + selector: + boolean: {} + +entity_analysis: + name: 实体历史记录 + description: 获取实体的历史状态记录(如人在传感器、灯光、温度、湿度、光照度变化记录等) + fields: + entity_id: + name: 实体ID + description: 要获取历史记录的实体ID + required: true + selector: + entity: + multiple: true + days: + name: 天数 + description: 要获取的历史记录天数(1-15天) + default: 3 + selector: + number: + min: 1 + max: 15 + mode: box diff --git a/custom_components/zhipuai/translations/en.json b/custom_components/zhipuai/translations/en.json new file mode 100644 index 0000000..36434d7 --- /dev/null +++ b/custom_components/zhipuai/translations/en.json @@ -0,0 +1,96 @@ +{ + "config": { + "step": { + "user": { + "data": { + "name": "custom name", + "api_key": "API key" + }, + "description": "Get the key: [Click the link](https://open.bigmodel.cn/console/modelft/dataset)" + }, + "reauth_confirm": { + "title": "Re-verify Zhipu AI", + "description": "Your Zhipu AI API key has expired, please enter a new API key", + "data": { + "api_key": "API key" + } + }, + "reconfigure_confirm": { + "title": "Reconfigure Zhipu AI", + "description": "Please enter new configuration information", + "data": { + "api_key": "API key" + } + }, + "history": { + "title": "Historical data analysis configuration", + "description": "Select the entities to be analyzed and the number of days of historical data", + "data": { + "history_entities": "Select entity", + "history_days": "Number of days of historical data (1-3 days)" + } + } + }, + "error": { + "cannot_connect": "Unable to connect to service", + "invalid_auth": "API key error", + "unknown": "unknown error", + "cooldown_too_small": "The cooling time value {value} is too small, please set a value greater than or equal to 0!", + "cooldown_too_large": "The cooling time value {value} is too large, please set a value less than or equal to 10!", + "model_not_found": "The specified model could not be found", + "invalid_api_key": "API Key format error", + "invalid_days": "The number of historical data days must be between 1-3 days", + "no_entities": "Please select at least one entity" + }, + "abort": { + "already_configured": "The device has been configured", + "reauth_successful": "Re-authentication successful", + "reconfigure_successful": "Reconfiguration successful" + } + }, + "options": { + "step": { + "init": { + "data": { + "chat_model": "chat model", + "temperature": "temperature", + "max_tokens": "Maximum number of tokens", + "max_history_messages": "Maximum number of historical messages", + "top_p": "Top P", + "prompt": "prompt word template", + "max_tool_iterations": "Maximum number of tool iterations", + "cooldown_period": "Cooling time (seconds)", + "llm_hass_api": "Home Assistant LLM API", + "recommended": "Use recommended model settings", + "web_search": "Internet analytics search", + "history_analysis": "Entity historical data analysis" + }, + "data_description": { + "prompt": "Indicates how LLM should respond. This can be a template.", + "chat_model": "Please select the chat model you want to use. By default, please select the free universal 128K model. If you want a better experience, you can choose to support other paid models. The actual cost is not high. Please check the official website billing standards for details.", + "max_tokens": "Set the maximum number of tokens returned in the response", + "temperature": "Controls the randomness of the output (0-2)", + "top_p": "Control output diversity (0-1)", + "llm_hass_api": "Home Assistant LLM API", + "recommended": "Use recommended model settings", + "max_history_messages": "Set the maximum number of historical messages to retain. Function: Control the memory function of input content. The memory function can ensure smooth contextual dialogue. Generally, it is best to control home equipment within 5 times. It is effective for requests that cannot be processed smoothly. For other daily conversations, the threshold can be set to more than 10 times.", + "max_tool_iterations": "Set the maximum number of tool calls in a single conversation. Its function is to set the call threshold for the system LLM call request. If an error occurs, it can ensure that the system will not get stuck. Especially for the design of various small hosts with weak performance, it is recommended to set it 20-30 times.", + "cooldown_period": "Set the minimum interval between two conversation requests (0-10 seconds). Function: The request will be delayed for a period of time before being sent. It is recommended to set it within 3 seconds to ensure that the content sending request fails due to frequency factors." + } + }, + "history": { + "title": "Entity historical data analysis configuration", + "description": "Provides **entity historical data analysis** in scenarios that **Jinja2 template** (Home Assistant's template system) cannot implement, ensuring that AI understands and analyzes your device data. For example: it can be used to automatically help you analyze home security , personnel activity trajectories, daily life summary, UI text template introduction, etc.\n\n• Support **AI-assisted analysis** historical data (let AI understand and analyze your device data)\n• Provide intelligent decision support for **device management**\n• It is recommended to control within the range of **1 day historical data** for best results", + "data": { + "history_entities": "Select entity", + "history_days": "Number of days of historical data (1-15 days)" + } + } + } + }, + "exceptions": { + "invalid_config_entry": { + "message": "The provided configuration entry is invalid. What you get is {config_entry}" + } + } +} \ No newline at end of file diff --git a/custom_components/zhipuai/translations/zh-Hans.json b/custom_components/zhipuai/translations/zh-Hans.json new file mode 100644 index 0000000..ca8c866 --- /dev/null +++ b/custom_components/zhipuai/translations/zh-Hans.json @@ -0,0 +1,96 @@ +{ + "config": { + "step": { + "user": { + "data": { + "name": "自定义名称", + "api_key": "API 密钥" + }, + "description": "获取密钥:[点击链接](https://open.bigmodel.cn/console/modelft/dataset)" + }, + "reauth_confirm": { + "title": "重新验证 智谱AI", + "description": "您的 智谱AI API密钥已失效,请输入新的 API 密钥", + "data": { + "api_key": "API 密钥" + } + }, + "reconfigure_confirm": { + "title": "重新配置 智谱AI", + "description": "请输入新的配置信息", + "data": { + "api_key": "API 密钥" + } + }, + "history": { + "title": "历史数据分析配置", + "description": "选择需要分析的实体和历史数据天数", + "data": { + "history_entities": "选择实体", + "history_days": "历史数据天数 (1-3天)" + } + } + }, + "error": { + "cannot_connect": "无法连接到服务", + "invalid_auth": "API密钥错误", + "unknown": "未知错误", + "cooldown_too_small": "冷却时间值 {value} 太小,请设置大于等于 0 的值!", + "cooldown_too_large": "冷却时间值 {value} 太大,请设置小于等于 10 的值!", + "model_not_found": "找不到指定的模型", + "invalid_api_key": "API Key 格式错误", + "invalid_days": "历史数据天数必须在1-3天之间", + "no_entities": "请至少选择一个实体" + }, + "abort": { + "already_configured": "设备已经配置", + "reauth_successful": "重新认证成功", + "reconfigure_successful": "重新配置成功" + } + }, + "options": { + "step": { + "init": { + "data": { + "chat_model": "聊天模型", + "temperature": "温度", + "max_tokens": "最大令牌数", + "max_history_messages": "最大历史消息数", + "top_p": "Top P", + "prompt": "提示词模板", + "max_tool_iterations": "最大工具迭代次数", + "cooldown_period": "冷却时间(秒)", + "llm_hass_api": "Home Assistant LLM API", + "recommended": "使用推荐的模型设置", + "web_search": "互联网分析搜索", + "history_analysis": "实体历史数据分析" + }, + "data_description": { + "prompt": "指示 LLM 应如何响应。这可以是一个模板。", + "chat_model": "请选择要使用的聊天模型,默认请选择免费通用128K模型,如需更好体验可选择支持其他付费模型,实际费用不高,具体请查看官网计费标准。", + "max_tokens": "设置响应中返回的最大令牌数", + "temperature": "控制输出的随机性(0-2)", + "top_p": "控制输出多样性(0-1)", + "llm_hass_api": "Home Assistant LLM API", + "recommended": "使用推荐的模型设置", + "max_history_messages": "设置要保留的最大历史消息数。功能:控制输入内容的记忆功能,记忆功能可以保证上下文对话顺畅,一般控制家居设备最好控制在5次以内,对请求不能顺利进行有效,其他日常对话可以设置阈值在10次以上。", + "max_tool_iterations": "设置单次对话中的最大工具调用次数。其功能是对系统LLM调用请求设置调用阈值,如果出错可以保证系统不会卡死,尤其是对各种性能较弱的小主机的设计,建议设置20-30次。", + "cooldown_period": "设置两次对话请求的最小间隔时间(0-10秒)。作用:请求会延迟一段时间再发送,建议设置在3秒以内,保证因为频率因素导致内容发送请求失败。" + } + }, + "history": { + "title": "实体历史数据分析 配置", + "description": "在**Jinja2模版**(Home Assistant的模板系统)无法实现的场景下提供**实体历史数据分析**,保证AI理解并分析您的设备数据,举例:可以用于自动化帮您分析家中安防、人员活动轨迹,日常生活总结,UI文本模版介绍等。\n\n• 支持**AI辅助分析**历史数据(让AI理解并分析您的设备数据)\n• 为**设备管理**提供智能决策支持\n• 建议控制在**1天历史数据**范围内以获得最佳效果", + "data": { + "history_entities": "选择实体", + "history_days": "历史数据天数 (1-15天)" + } + } + } + }, + "exceptions": { + "invalid_config_entry": { + "message": "提供的配置条目无效。得到的是 {config_entry}" + } + } +} diff --git a/custom_components/zhipuai/web_search.py b/custom_components/zhipuai/web_search.py new file mode 100644 index 0000000..258c27d --- /dev/null +++ b/custom_components/zhipuai/web_search.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import json +import logging +import uuid +import time +from typing import Any, Dict + +import voluptuous as vol +import requests +from homeassistant.core import HomeAssistant, ServiceCall, callback +from homeassistant.exceptions import HomeAssistantError, ServiceValidationError +from homeassistant.helpers import config_validation as cv + +from .const import ( + DOMAIN, + LOGGER, + ZHIPUAI_WEB_SEARCH_URL, +) + +WEB_SEARCH_SCHEMA = vol.Schema({ + vol.Required("query"): cv.string, + vol.Optional("stream", default=False): cv.boolean, +}) + +async def async_setup_web_search(hass: HomeAssistant) -> None: + + async def handle_web_search(call: ServiceCall) -> None: + try: + config_entries = hass.config_entries.async_entries(DOMAIN) + if not config_entries: + raise ValueError("未找到 ZhipuAI 配置") + api_key = config_entries[0].data.get("api_key") + if not api_key: + raise ValueError("在配置中找不到 API 密钥") + + query = call.data["query"] + stream = call.data.get("stream", False) + request_id = str(uuid.uuid4()) + + messages = [{"role": "user", "content": query}] + payload = { + "request_id": request_id, + "tool": "web-search-pro", + "stream": stream, + "messages": messages + } + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + try: + response = await hass.async_add_executor_job( + lambda: requests.post( + ZHIPUAI_WEB_SEARCH_URL, + headers=headers, + json=payload, + stream=stream, + timeout=300 + ) + ) + response.raise_for_status() + except requests.exceptions.RequestException as e: + LOGGER.error(f"API请求失败: {str(e)}") + raise ServiceValidationError(f"API请求失败: {str(e)}") + + if stream: + event_id = f"zhipuai_response_{int(time.time())}" + + try: + hass.bus.async_fire(f"{DOMAIN}_stream_start", { + "event_id": event_id, + "type": "web_search" + }) + + accumulated_text = "" + for line in response.iter_lines(): + if line: + try: + line_text = line.decode('utf-8').strip() + if line_text.startswith('data: '): + line_text = line_text[6:] + if line_text == '[DONE]': + break + + json_data = json.loads(line_text) + if 'choices' in json_data and json_data['choices']: + choice = json_data['choices'][0] + if 'message' in choice and 'tool_calls' in choice['message']: + tool_calls = choice['message']['tool_calls'] + for tool_call in tool_calls: + if tool_call.get('type') == 'search_result': + search_results = tool_call.get('search_result', []) + for result in search_results: + content = result.get('content', '') + if content: + accumulated_text += content + "\n" + hass.bus.async_fire( + f"{DOMAIN}_stream_token", + { + "event_id": event_id, + "content": content, + "full_content": accumulated_text + } + ) + except json.JSONDecodeError as e: + LOGGER.error(f"解析流式响应失败: {str(e)}") + continue + + hass.bus.async_fire(f"{DOMAIN}_stream_end", { + "event_id": event_id, + "full_content": accumulated_text + }) + + hass.bus.async_fire(f"{DOMAIN}_response", { + "type": "web_search", + "content": accumulated_text, + "success": True + }) + return {"success": True, "event_id": event_id, "message": accumulated_text} + + except Exception as e: + error_msg = f"处理流式响应时出错: {str(e)}" + LOGGER.error(error_msg) + hass.bus.async_fire(f"{DOMAIN}_stream_error", { + "event_id": event_id, + "error": error_msg + }) + return {"success": False, "message": error_msg} + else: + result = response.json() + content = "" + if result.get("choices") and result["choices"][0].get("message", {}).get("tool_calls"): + tool_calls = result["choices"][0]["message"]["tool_calls"] + for tool_call in tool_calls: + if tool_call.get("type") == "search_result": + search_results = tool_call.get("search_result", []) + for result in search_results: + if result_content := result.get("content"): + content += result_content + "\n" + + if content: + hass.bus.async_fire(f"{DOMAIN}_response", { + "type": "web_search", + "content": content, + "success": True + }) + return {"success": True, "message": content} + else: + error_msg = "未从API获取到搜索结果" + return {"success": False, "message": error_msg} + + except Exception as e: + error_msg = f"Web search failed: {str(e)}" + LOGGER.error(f"网络搜索错误: {str(e)}") + return {"success": False, "message": error_msg} + + hass.services.async_register( + DOMAIN, + "web_search", + handle_web_search, + schema=WEB_SEARCH_SCHEMA, + supports_response=True + ) + + @callback + def async_unload_services() -> None: + hass.services.async_remove(DOMAIN, "web_search") + + return async_unload_services diff --git a/hacs.json b/hacs.json new file mode 100644 index 0000000..e59be77 --- /dev/null +++ b/hacs.json @@ -0,0 +1,5 @@ +{ + "name": "智谱清言", + "render_readme": true, + "homeassistant": "2024.8.0" +}