Skip to content

Commit

Permalink
修补01.16
Browse files Browse the repository at this point in the history
  • Loading branch information
knoop7 authored Jan 16, 2025
1 parent 68dedcb commit 4780f7e
Show file tree
Hide file tree
Showing 7 changed files with 472 additions and 115 deletions.
10 changes: 5 additions & 5 deletions custom_components/zhipuai/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,16 +336,16 @@ def zhipuai_config_option_schema(
CONF_RECOMMENDED,
default=options.get(CONF_RECOMMENDED, False)
): bool,
vol.Optional(
CONF_MAX_HISTORY_MESSAGES,
description={"suggested_value": options.get(CONF_MAX_HISTORY_MESSAGES)},
default=RECOMMENDED_MAX_HISTORY_MESSAGES,
): int,
vol.Optional(
CONF_CHAT_MODEL,
description={"suggested_value": options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)},
default=RECOMMENDED_CHAT_MODEL,
): SelectSelector(SelectSelectorConfig(options=ZHIPUAI_MODELS)),
vol.Optional(
CONF_MAX_HISTORY_MESSAGES,
description={"suggested_value": options.get(CONF_MAX_HISTORY_MESSAGES)},
default=RECOMMENDED_MAX_HISTORY_MESSAGES,
): int,
vol.Optional(
CONF_MAX_TOOL_ITERATIONS,
description={"suggested_value": options.get(CONF_MAX_TOOL_ITERATIONS, DEFAULT_MAX_TOOL_ITERATIONS)},
Expand Down
2 changes: 1 addition & 1 deletion custom_components/zhipuai/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
RECOMMENDED_MAX_HISTORY_MESSAGES = 5

CONF_MAX_TOOL_ITERATIONS = "max_tool_iterations"
DEFAULT_MAX_TOOL_ITERATIONS = 5
DEFAULT_MAX_TOOL_ITERATIONS = 20
CONF_COOLDOWN_PERIOD = "cooldown_period"
DEFAULT_COOLDOWN_PERIOD = 1

Expand Down
157 changes: 83 additions & 74 deletions custom_components/zhipuai/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,28 +94,15 @@ def is_service_call(user_input: str) -> bool:
"press": ["按", "按下", "点击"],
"select": ["选择", "下一个", "上一个", "第一个", "最后一个"],
"trigger": ["触发", "调用"],
"media": ["暂停", "继续播放", "播放", "停止", "下一首", "下一曲", "下一个", "切歌", "换歌","上一首", "上一曲", "上一个", "返回上一首", "音量"],
"climate": ["制冷", "制热", "风速", "模式", "调温", "调到", "设置",
"空调", "冷气", "暖气", "冷风", "暖风", "自动模式", "除湿", "送风",
"高档", "低档", "高速", "低速", "自动高", "自动低", "强劲", "自动"]
"number": ["数字", "数值"],
"media": ["暂停", "继续播放", "播放", "停止", "下一首", "下一曲", "下一个", "切歌", "换歌","上一首", "上一曲", "上一个", "返回上一首", "音量"]
}
}

has_pattern = bool(user_input and (
return bool(user_input and (
any(k in user_input for k in patterns["control"]) or
any(k in user_input for action in patterns["action"].values() for k in (action if isinstance(action, list) else []))
))

has_climate = (
"空调" in user_input and (
bool(re.search(r"\d+", user_input)) or
any(k in user_input for k in patterns["action"]["climate"])
)
)

has_entity = bool(re.search(r"([\w_]+\.[\w_]+)", user_input))

return has_pattern or has_climate or has_entity

def extract_service_info(user_input: str, hass: HomeAssistant) -> Optional[Dict[str, Any]]:
def find_entity(domain: str, text: str) -> Optional[str]:
Expand Down Expand Up @@ -165,48 +152,8 @@ def clean_text(text: str, patterns: List[str]) -> str:
for domain, service in [("script", "turn_on"), ("automation", "trigger"), ("scene", "turn_on")]
if (entity_id := find_entity(domain, name))), None)

climate_patterns = {"温度": "set_temperature", "制冷": "set_hvac_mode", "制热": "set_hvac_mode",
"风速": "set_fan_mode", "模式": "set_hvac_mode", "湿度": "set_humidity",
"调温": "set_temperature", "调到": "set_temperature", "设置": "set_hvac_mode",
"度": "set_temperature"}

if entity_id := find_entity("climate", user_input):
temperature_match = re.search(r'(\d+)(?:\s*度)?', user_input)
if temperature_match and ("温度" in user_input or "调温" in user_input or "调到" in user_input or "度" in user_input):
temperature = int(temperature_match.group(1))
current_temp = hass.states.get(entity_id).attributes.get('current_temperature')
hvac_mode = "cool" if current_temp > temperature else "heat" if current_temp is not None else None
return {"domain": "climate", "service": "set_temperature",
"data": {"entity_id": entity_id, "temperature": temperature, "hvac_mode": hvac_mode}} if hvac_mode else {"domain": "climate", "service": "set_temperature",
"data": {"entity_id": entity_id, "temperature": temperature}}

for pattern, service in climate_patterns.items():
if pattern in user_input.lower():
if service == "set_temperature":
temperature_match = re.search(r'(\d+)', user_input)
if temperature_match:
temperature = int(temperature_match.group(1))
current_temp = hass.states.get(entity_id).attributes.get('current_temperature')
hvac_mode = "cool" if current_temp > temperature else "heat" if current_temp is not None else None
return {"domain": "climate", "service": service,
"data": {"entity_id": entity_id, "temperature": temperature, "hvac_mode": hvac_mode}} if hvac_mode else {"domain": "climate", "service": service,
"data": {"entity_id": entity_id, "temperature": temperature}}
elif service == "set_hvac_mode":
hvac_mode_map = {"制冷": "cool", "制热": "heat", "自动": "auto", "除湿": "dry",
"送风": "fan_only", "关闭": "off", "停止": "off"}
return {"domain": "climate", "service": service,
"data": {"entity_id": entity_id, "hvac_mode": next((mode for cn, mode in hvac_mode_map.items() if cn in user_input), "auto")}}
elif service == "set_fan_mode":
fan_mode_map = {"高档": "on_high", "高速": "on_high", "强劲": "on_high",
"低档": "on_low", "低速": "on_low", "自动高": "auto_high",
"自动高档": "auto_high", "自动低": "auto_low", "自动低档": "auto_low",
"关闭": "off", "停止": "off"}
return {"domain": "climate", "service": service,
"data": {"entity_id": entity_id, "fan_mode": next((mode for cn, mode in fan_mode_map.items() if cn in user_input), "auto_low")}}
elif service == "set_humidity":
humidity_match = re.search(r'(\d+)', user_input)
return {"domain": "climate", "service": service,
"data": {"entity_id": entity_id, "humidity": int(humidity_match.group(1))}} if humidity_match else None
if any(p in user_input for p in ["数字", "数值"]) and (number_match := re.search(r'\d+(?:\.\d+)?', user_input)) and (entity_id := find_entity("number", clean_text(user_input, ["数字", "数值"]))):
return {"domain": "number", "service": "set_value", "data": {"entity_id": entity_id, "value": number_match.group(0)}}

return None

Expand All @@ -223,14 +170,14 @@ def __init__(self, entry: ConfigEntry, hass: HomeAssistant) -> None:
self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, entry.entry_id)},
name=entry.title,
manufacturer="智谱清言",
model="ChatGLM Pro",
manufacturer="北京智谱华章科技",
model="ChatGLM AI",
entry_type=dr.DeviceEntryType.SERVICE,
)
if self.entry.options.get(CONF_LLM_HASS_API) and self.entry.options.get(CONF_LLM_HASS_API) != "none":
self._attr_supported_features = conversation.ConversationEntityFeature.CONTROL
self.last_request_time = 0
self.max_tool_iterations = min(entry.options.get(CONF_MAX_TOOL_ITERATIONS, DEFAULT_MAX_TOOL_ITERATIONS), 5)
self.max_tool_iterations = min(entry.options.get(CONF_MAX_TOOL_ITERATIONS, DEFAULT_MAX_TOOL_ITERATIONS), 30)
self.cooldown_period = entry.options.get(CONF_COOLDOWN_PERIOD, DEFAULT_COOLDOWN_PERIOD)
self.llm_api = None
self.intent_handler = IntentHandler(hass)
Expand Down Expand Up @@ -315,7 +262,7 @@ async def async_process(self, user_input: conversation.ConversationInput) -> con
try:
if options.get(CONF_LLM_HASS_API) and options[CONF_LLM_HASS_API] != "none":
self.llm_api = await llm.async_get_api(self.hass, options[CONF_LLM_HASS_API], llm_context)
tools = [_format_tool(tool, self.llm_api.custom_serializer) for tool in self.llm_api.tools][:8]
tools = [_format_tool(tool, self.llm_api.custom_serializer) for tool in self.llm_api.tools]

if not options.get(CONF_WEB_SEARCH, DEFAULT_WEB_SEARCH):
if any(term in user_input.text.lower() for term in ["联网", "查询", "网页", "search"]):
Expand Down Expand Up @@ -392,7 +339,7 @@ async def async_process(self, user_input: conversation.ConversationInput) -> con
start_time = end_time - timedelta(days=days)

history_text = []
history_text.append(f"\n以下是询问者所关注的实体的历史数据分析({days}天内):")
history_text.append("以下是询问者所关注的实体的历史数据分析({}天内):".format(days))

instance = get_instance(self.hass)
history_data = await instance.async_add_executor_job(
Expand All @@ -409,13 +356,13 @@ async def async_process(self, user_input: conversation.ConversationInput) -> con
for entity_id in entities:
state = self.hass.states.get(entity_id)
if state is None or entity_id not in history_data or not history_data[entity_id]:
history_text.append(f"\n{entity_id} (当前状态):")
history_text.append("{} (当前状态):".format(entity_id))
history_text.append(
f"- {state.state if state else 'unknown'} ({state.last_updated.astimezone().strftime('%m-%d %H:%M:%S') if state else 'unknown'})"
"- {} ({})".format(state.state if state else 'unknown', state.last_updated.astimezone().strftime('%m-%d %H:%M:%S') if state else 'unknown')
)
else:
states = history_data[entity_id]
history_text.append(f"\n{entity_id} (历史状态变化):")
history_text.append("{} (历史状态变化):".format(entity_id))
last_state_text = None
last_time = None
for state in states:
Expand All @@ -427,13 +374,18 @@ async def async_process(self, user_input: conversation.ConversationInput) -> con
if last_time and (current_time - last_time).total_seconds() < interval_minutes * 60:
continue

state_text = f"- {state.state} ({current_time.strftime('%m-%d %H:%M:%S')})"
state_text = "- {} ({})".format(state.state, current_time.strftime('%m-%d %H:%M:%S'))
if state_text != last_state_text:
history_text.append(state_text)
last_state_text = state_text
last_time = current_time
if len(history_text) > 1:
prompt_parts.append("\n".join(history_text))
history_text_str = "\n".join(history_text).strip()
if history_text_str:
prompt_parts.append({
"type": "history_analysis",
"content": history_text
})

except Exception as err:
LOGGER.warning(f"获取历史数据时出错: {err}")
Expand All @@ -445,13 +397,70 @@ async def async_process(self, user_input: conversation.ConversationInput) -> con
return conversation.ConversationResult(response=intent_response, conversation_id=conversation_id)

if self.llm_api:
prompt_parts.append(self.llm_api.api_prompt)
api_instructions = [line.strip() if line.startswith(" ") else line for line in self.llm_api.api_prompt.split('\n') if line.strip()]
prompt_parts.append({
"type": "api_instructions",
"content": api_instructions
})

base_prompt = prompt_parts[0]
base_instructions = [line.strip() if line.startswith(" ") else line for line in base_prompt.split('\n') if line.strip()]
prompt_parts[0] = {
"type": "system_instructions",
"content": base_instructions
}

prompt = "\n".join(prompt_parts)
LOGGER.info("提示部件: %s", prompt_parts)
def format_modes(modes):
if not modes:
return []
if isinstance(modes, str):
return [m.strip() for m in modes.split(',')]
return [str(mode) for mode in modes]

climate_entities = [state for state in self.hass.states.async_all() if state.domain == "climate"]
if climate_entities:
content = []
for entity in climate_entities:
attrs = entity.attributes
hvac_modes = format_modes(attrs.get('hvac_modes', []))
fan_modes = format_modes(attrs.get('fan_modes', []))
swing_modes = format_modes(attrs.get('swing_modes', []))

entity_info = [
f"- names: {attrs.get('friendly_name', entity.entity_id)}",
f"domain: climate",
f"state: {entity.state}",
"attributes:",
f"current_temperature: {attrs.get('current_temperature')}",
f"temperature: {attrs.get('temperature')}",
f"min_temp: {attrs.get('min_temp')}",
f"max_temp: {attrs.get('max_temp')}",
f"target_temp_step: {attrs.get('target_temp_step')}",
f"hvac_modes: {hvac_modes}",
f"fan_modes: {fan_modes}",
f"swing_modes: {swing_modes}",
f"hvac_action: {attrs.get('hvac_action')}",
f"fan_mode: {attrs.get('fan_mode')}",
f"swing_mode: {attrs.get('swing_mode')}",
f"current_humidity: {attrs.get('current_humidity')}",
f"humidity: {attrs.get('humidity')}"
]
content.extend(entity_info)

prompt_parts.append({"type": "climate_status", "content": content})

prompt_json = json.dumps(prompt_parts, ensure_ascii=False, separators=(',', ':'))
LOGGER.info("提示部件: %s", prompt_json)

all_lines = []
for part in prompt_parts:
if isinstance(part["content"], list):
all_lines.extend([line.strip() if line.startswith(" ") else line for line in part["content"]])
else:
all_lines.extend([line.strip() if line.startswith(" ") else line for line in part["content"].split('\n') if line.strip()])

messages = [
ChatCompletionMessageParam(role="system", content=prompt),
ChatCompletionMessageParam(role="system", content="\n".join(all_lines)),
*(messages if use_history else []),
ChatCompletionMessageParam(role="user", content=user_input.text),
]
Expand All @@ -467,7 +476,7 @@ async def async_process(self, user_input: conversation.ConversationInput) -> con
"max_tokens": min(options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), 4096),
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
"request_id": conversation_id,
"request_id": conversation_id
}
if tools:
payload["tools"] = tools
Expand Down Expand Up @@ -497,7 +506,7 @@ async def async_process(self, user_input: conversation.ConversationInput) -> con
if isinstance(tool_response, dict) and "error" in tool_response:
raise Exception(tool_response["error"])

formatted_response = json.dumps(tool_response)
formatted_response = json.dumps(tool_response, ensure_ascii=False) if isinstance(tool_response, (dict, list)) else str(tool_response)
messages.append(
ChatCompletionMessageParam(
role="tool",
Expand Down
Loading

0 comments on commit 4780f7e

Please sign in to comment.