diff --git a/metagpt/ext/ai_writer/app_demo.py b/metagpt/ext/ai_writer/app_demo.py new file mode 100644 index 000000000..418957c6b --- /dev/null +++ b/metagpt/ext/ai_writer/app_demo.py @@ -0,0 +1,364 @@ +import gradio as gr +import json,os ,shutil +from typing import Optional +from pathlib import Path +from pydantic import Field +from metagpt.const import DATA_PATH +from metagpt.schema import Message, Plan +from metagpt.actions import SearchAndSummarize, UserRequirement +from metagpt.ext.ai_writer.write_planner import DocumentPlan, WritePlanner +from metagpt.ext.ai_writer.write_refine import WriteGuide,Refine,WriteSubsection,Clean +from metagpt.ext.ai_writer.document import build_engine +from metagpt.ext.ai_writer.utils import WriteOutFile + +class DocumentGenerator(WritePlanner): + """ + 继承自WritePlanner的类,用于生成文档。 + """ + store: Optional[object] = Field(default=None, exclude=True) + def add_file_button(self, topic, history, add_file_button): + ref_dir = DATA_PATH / f"ai_writer/ref/{topic}" + persist_dir = DATA_PATH / f"persist/{topic}" + + if not os.path.isdir(ref_dir): + os.makedirs(ref_dir) + for file in add_file_button: + shutil.move(file, ref_dir) + history.append(('完成解析', file)) + + model ='model/bge-large-zh-v1.5' + self.store = build_engine(ref_dir,persist_dir,model) + shutil.rmtree(ref_dir) + return history + + async def generate_outline(self): + """ + 异步生成文档大纲。 + """ + context = self.get_useful_memories() + response = await DocumentPlan().run(context) + self.working_memory.add(Message(content=response, role="assistant", cause_by=DocumentPlan)) + return response + + async def gen_outline_button(self, requirements): + self.plan = Plan(goal=requirements.strip()) + response = await self.generate_outline() + return [(requirements, response)] + + + async def submit_outline_button(self, user_input, conversation_history): + self.working_memory.add(Message(content=user_input, role="user")) + response = await self.generate_outline() + conversation_history.append((user_input, response)) + return "Outline updated", conversation_history + + def confirm_outline_button(self, requirements, history, outline): + self.plan = Plan(goal=requirements.strip()) + if not outline: + outline = history[-1][-1] if history else '' + '''根据大纲建文档目录树状结构''' + rsp = self.post_process_chapter_id_or_name(outline) + self.titlehierarchy = self.process_response_and_build_hierarchy(rsp=rsp) + return outline + + + def get_name_and_subheading(self,id): + obj = self.titlehierarchy.get_chapter_obj_by_id(id) + chapter_name = obj.name + subheadings = self.titlehierarchy.get_subheadings_by_prefix(id) + return chapter_name, '\n'.join(subheadings) + + async def retrieve_button(self,chapter_name): + contexts = '请上传关联文件' + if self.store: + contexts = await self.store.aretrieve(chapter_name) + contexts = '\n\n'.join([x.text for x in contexts]) + return contexts + + async def retrieve_clean(self,title, contexts): + context = '' + if self.store: + context = await Clean().run(title = title, contexts = contexts) + return context + + + + async def gen_guide(self, chapter_id, chapter_name, subheadings, history): + if subheadings: + contexts = await self.retrieve_button(chapter_name) + guideline = await WriteGuide().run( + user_requirement= self.plan.goal, + chapter_name = chapter_name, + subheadings = ','.join([section for section in subheadings]), + contexts = contexts + ) + history.append((f'{chapter_id} {chapter_name}', guideline)) + self.titlehierarchy.set_content_by_id(chapter_id, guideline) + yield history + + for subheading in subheadings: + chapter_id, name = subheading.split(' ') + subtitle = self.titlehierarchy.get_subheadings_by_prefix(chapter_id) + async for output in self.gen_guide(chapter_id, name, subtitle, history): + yield output + + async def gen_guide_button(self, chapter_id, history): + history = [] + subheadings = self.titlehierarchy.get_subheadings_by_prefix(chapter_id) + chapter_name = self.titlehierarchy.get_chapter_obj_by_id(chapter_id).name + async for output in self.gen_guide(chapter_id, chapter_name, subheadings, history): + yield output + + async def write_paragraph(self, parent_id, child_id, chapter_name, subheadings, history): + if subheadings: + guidelines = self.titlehierarchy.get_chapter_obj_by_id(parent_id).content + history.append((f'{child_id} {chapter_name}', guidelines)) + yield history + + for subheading in subheadings: + child_id, chapter_name = subheading.split(' ') + child_heading = self.titlehierarchy.get_subheadings_by_prefix(child_id) + async for output in self.write_paragraph(parent_id, child_id, chapter_name, child_heading, history): + yield output + else: + contexts = await self.retrieve_button(chapter_name) + guidelines = self.titlehierarchy.get_chapter_obj_by_id(parent_id).content + gen_paragraph = await WriteSubsection().run( + subsection = chapter_name, + contexts = f'{guidelines}\n\n# Reference: \n```{contexts}```' + ) + history.append((f'{child_id} {chapter_name}', gen_paragraph)) + yield history + + + async def write_paragraph_button(self, chapter_id, history): + history = [] + subheadings = self.titlehierarchy.get_subheadings_by_prefix(chapter_id) + chapter_name = self.titlehierarchy.get_chapter_obj_by_id(chapter_id).name + async for output in self.write_paragraph(chapter_id, chapter_id, chapter_name, subheadings, history): + yield output + + + + async def refine_button(self, revise_id, instrution, addition_context, revise_text): + obj = self.titlehierarchy.get_chapter_obj_by_id(revise_id.lstrip()) + chapter_name , pre_result = obj.name , obj.content + cur_result = await Refine().run( + user_requirement = instrution, + original_query = chapter_name, + respones = '\n\n'.join([pre_result, revise_text]) , + contexts = addition_context + ) + return cur_result + + async def web_button(self, revise_id, instrution): + chapter_name = self.titlehierarchy.get_chapter_obj_by_id(revise_id.lstrip()).name + prompt = instrution if instrution else chapter_name + message = [Message(content= prompt, role="user", cause_by = UserRequirement)] + addition_context = await SearchAndSummarize().run(message) + return addition_context + + def commit_button(self, revise_id, revise_text, chatbot): + self.titlehierarchy.set_content_by_id(revise_id, revise_text) + new_chatbot = [] + for title, content in chatbot: + cur_id, _ = title.split(' ') + if cur_id == revise_id: + new_chatbot.append((title, revise_text)) + else: + new_chatbot.append((title, content)) + return new_chatbot + + def download_button(self,topic): + output_path = DATA_PATH / f"ai_writer/outputs/{topic}.docx" + if not output_path.exists(): + + WriteOutFile.write_word_file(topic = topic, + tasks= self.titlehierarchy.traverse_and_output(), + output_path = output_path + ) + + return gr.DownloadButton(label=f"Download", value= output_path, visible=True) + + @staticmethod + def create_directory_structure_botton(data, prefix="", is_last=True): + data = json.loads(data) + chatbot = '' + for index, item in enumerate(data): + chapter_name = item["chapter_name"] + subheadings = item.get("subheadings", []) + # Determine the prefix based on whether it's the last item in its level + current_prefix = f"{prefix}{'└── ' if is_last else '├── '}" + # Print the main directory + chatbot += (f"{current_prefix}{chapter_name}\n") + # Update the prefix for subdirectories + next_prefix = f"{prefix}{' ' if is_last else '│ '}" + # If there are subheadings, handle them differently to ensure correct indentation + if subheadings: + chatbot += (f"{next_prefix}├──\n") + for sub_index, subheading in enumerate(subheadings): + sub_current_prefix = f"{next_prefix}{' ' if sub_index == len(subheadings) - 1 else '│ '}" + chatbot += (f"{sub_current_prefix}├── {subheading}\n") + return [('',chatbot)] + + +doc_gen = DocumentGenerator() + +async def main(): + with gr.Blocks(css="") as demo: + gr.Markdown("## AI 智能文档写作 Demo") + with gr.Row(): + with gr.Column(scale=0, elem_id="row1"): + with gr.Tab("开始"): + topic = gr.Textbox( + "产业数字化对中国出口隐含碳的影响", + label="话题", + lines=7, + interactive=True, + ) + user_requriments = gr.Textbox( + "写一个完整、连贯的《产业数字化对中国出口隐含碳的影响》文档, 确保文字精确、逻辑清晰,并保持专业和客观的写作风格,中文书写", + label="用户需求", + lines=9, + interactive=True, + ) + add_file_button = gr.UploadButton("📁 Upload (上传文件)",file_count="multiple") + gen_outline_button = gr.Button("生成大纲") + + + with gr.Tab("大纲"): + outline_box = gr.Textbox(label="大纲", + lines=16, + interactive=True) + + user_input = gr.Textbox('eg:请帮我新增章节', + lines=2, + label='大纲修订(增删改)') + submit_outline_button = gr.Button("提交") + confirm_outline_button = gr.Button("确认") + + + with gr.Tab("生成段落"): + chapter_id = gr.Textbox('1',label="chapter_id", lines=1, interactive=True) + chapter_name = gr.Textbox('',label="大章节名称", lines = 1, interactive=False) + chapter_subname = gr.Textbox('',label="小节名称", lines=2, interactive=False) + retrieve_bot = gr.Textbox('',label="资源检索", lines=5, interactive=False) + retrieve_button = gr.Button("资源检索") + gen_guide_button = gr.Button("生成指南") + write_paragraph_button = gr.Button("生成段落") + + + with gr.Tab("功能区"): + instrution = gr.Textbox(label="润色指令", lines=4, interactive=True) + addition_context = gr.Textbox(label="临时新增内容", lines=10, interactive=True) + refine_button = gr.Button("润色") + web_button = gr.Button("联网") + download_button = gr.DownloadButton("下载", + visible=True,) + + with gr.Column(scale=3, elem_id="row2"): + chatbot = gr.Chatbot(label='output', height=690) + + with gr.Column(scale=0, elem_id="row3"): + revise_text = gr.Textbox( + label="修订", lines=30, interactive=True, show_copy_button=True + ) + commit_button = gr.Button("确认") + + add_file_button.upload(doc_gen.add_file_button, + inputs=[topic, chatbot, add_file_button], + outputs = [chatbot], + show_progress = True + ) + + + gen_outline_button.click(doc_gen.gen_outline_button, + inputs=[user_requriments], + outputs=[chatbot], + show_progress = True + ) + + submit_outline_button.click( + doc_gen.submit_outline_button, + inputs=[user_input, chatbot], + outputs=[user_input, chatbot] + ) + + + confirm_outline_button.click( + doc_gen.confirm_outline_button, + inputs=[user_requriments, chatbot, outline_box], + outputs=[outline_box] + ).then( + doc_gen.create_directory_structure_botton, + inputs=[outline_box], + outputs=[chatbot] + ) + + retrieve_button.click( + doc_gen.get_name_and_subheading, + inputs = [chapter_id], + outputs= [chapter_name,chapter_subname] + ).then( + doc_gen.retrieve_button, + inputs = [chapter_name], + outputs= [retrieve_bot] + ).then( + doc_gen.retrieve_clean, + inputs = [chapter_name,retrieve_bot], + outputs= [revise_text] + ) + + + gen_guide_button.click( + doc_gen.get_name_and_subheading, + inputs = [chapter_id], + outputs= [chapter_name,chapter_subname] + ).then( + doc_gen.gen_guide_button, + inputs=[chapter_id, chatbot], + outputs=[chatbot] + ) + + write_paragraph_button.click( + doc_gen.write_paragraph_button, + inputs=[chapter_id, chatbot], + outputs=[chatbot] + ) + + refine_button.click(doc_gen.refine_button, + inputs= [chapter_id, instrution,addition_context,revise_text], + outputs=[revise_text], + show_progress = True + ) + + web_button.click(doc_gen.web_button, + inputs= [chapter_id, instrution], + outputs=[addition_context], + show_progress = True + ) + + commit_button.click(doc_gen.commit_button, + inputs=[chapter_id, revise_text, chatbot], + outputs=[chatbot] + ) + + + download_button.click( + doc_gen.download_button, + inputs=[topic], + outputs=download_button, + show_progress = True + ) + + + demo.queue().launch( + share=True, + inbrowser=False, + server_port=8888, + server_name="0.0.0.0" + ) + +if __name__ == "__main__": + import asyncio + asyncio.run(main()) diff --git a/metagpt/ext/ai_writer/document.py b/metagpt/ext/ai_writer/document.py new file mode 100644 index 000000000..c59c2ef18 --- /dev/null +++ b/metagpt/ext/ai_writer/document.py @@ -0,0 +1,50 @@ +import fire +from pathlib import Path +from metagpt.const import DATA_PATH +from metagpt.rag.engines import SimpleEngine +from metagpt.rag.schema import FAISSRetrieverConfig, BM25RetrieverConfig, BGERerankConfig,FAISSIndexConfig +from metagpt.ext.ai_writer.writer import DocumentWriter +from metagpt.ext.ai_writer.utils import WriteOutFile + +def build_engine(ref_dir: Path, persist_dir: Path, model_name: str = "bge-large-zh") -> SimpleEngine: + retriever_configs = [FAISSRetrieverConfig(similarity_top_k=10), BM25RetrieverConfig(similarity_top_k=20)] + ranker_configs = [BGERerankConfig(model=model_name, top_n=10)] + + if persist_dir.exists(): + engine = SimpleEngine.from_index(index_config=FAISSIndexConfig(persist_path=persist_dir), + retriever_configs=retriever_configs, + ranker_configs=ranker_configs) + else: + engine = SimpleEngine.from_docs(input_dir=ref_dir, + retriever_configs=retriever_configs, + ranker_configs=ranker_configs) + engine.retriever.persist(persist_dir=persist_dir) + return engine + + + +REQUIREMENT = "写一个完整、连贯的《{topic}》文档, 确保文字精确、逻辑清晰,并保持专业和客观的写作风格。" +topic = "产业数字化对中国出口隐含碳的影响" + +async def main(auto_run: bool = False, use_store:bool = True): + ref_dir = DATA_PATH / f"ai_writer/ref/{topic}" + persist_dir = DATA_PATH / f"ai_writer/persist/{topic}" + output_path = DATA_PATH / f"ai_writer/outputs" + model ='model/bge-large-zh-v1.5' + + if use_store: + store = build_engine(ref_dir,persist_dir,model) + dw = DocumentWriter(auto_run=auto_run,store = store) + else: + dw = DocumentWriter(auto_run=auto_run,use_store = False) + + requirement = REQUIREMENT.format(topic = topic) + await dw.run(requirement) + + # write out word + WriteOutFile.write_word_file(topic = topic, + tasks= dw.planner.titlehierarchy.traverse_and_output(), + output_path = output_path) + +if __name__ == "__main__": + fire.Fire(main) diff --git a/metagpt/ext/ai_writer/utils.py b/metagpt/ext/ai_writer/utils.py new file mode 100644 index 000000000..0e71c71dc --- /dev/null +++ b/metagpt/ext/ai_writer/utils.py @@ -0,0 +1,211 @@ +from pydantic import BaseModel +from typing import List, Union, Any +from pathlib import Path +import re +from datetime import datetime +from docx import Document + +def colored_decorator(color): + def print_colored(func): + async def wrapper(self, *args, **kwargs): + # ANSI escape codes for colored terminal output (optional, can be removed if not needed) + print(color) + contexts = await func(self, *args, **kwargs) + print("\033[0m") + return contexts + return wrapper + return print_colored + + +def print_time(func): + async def wrapper(self, *args, **kwargs): + # ANSI escape codes for colored terminal output (optional, can be removed if not needed) + start_time = datetime.now() + objs = await func(self, *args, **kwargs) + end_time = datetime.now() + print(f"执行时间:{(end_time - start_time).total_seconds()}秒") + return objs + return wrapper + + +class Node: + def __init__(self, name='',content = ''): + self.name = name + self.content = content + self.subheading = {} + +class TitleHierarchy: + def __init__(self): + self.root = Node() + + def add_chapter(self, id: str, name: str): + # Create a new node for the chapter + chapter_node = Node(name) + # Add the new chapter node to the root's subheading + self.root.subheading[id] = chapter_node + + + def add_title(self, title: str): + parts = title.split(' ') + path = parts[0].split('.') + name = ' '.join(parts[1:]) + current = self.root + + for part in path: + if part not in current.subheading: + current.subheading[part] = Node() + current = current.subheading[part] + + current.name = name + + def add_subheadings(self, titles): + """ + Add multiple subheadings to the hierarchy. + + param titles: A list of title strings to be added to the hierarchy. + """ + for title in titles: + self.add_title(title) + + + def get_subheadings_by_prefix(self, prefix): + """ + Get subheadings based on the given prefix. + + param root: The root Node of the hierarchical structure representing the relationship between titles and subtitles. + param prefix: The prefix string to filter the subheadings. + return: A list of subheading names that match the given prefix. + """ + # Split the prefix into its components + prefix_parts = prefix.split('.') + + # Navigate through the structure using the prefix + current = self.root + for part in prefix_parts: + if part in current.subheading: + current = current.subheading[part] + else: + # If any part of the prefix is not found, return an empty list + return [] + + # Collect all subheading names at this level + subheadings = [] + for key, value in current.subheading.items(): + subheadings.append(f"{prefix}.{key} {value.name}") + + return subheadings + + def set_content_by_id(self, path: str, content: str): + """ + Set the content of a node identified by its path. + + param path: A string representing the path to the node, formatted as "x.y.z...". + param content: The content to be set for the node. + """ + # Split the path into its components + path_parts = path.split('.') + + # Navigate through the structure using the path + current = self.root + for part in path_parts: + if part in current.subheading: + current = current.subheading[part] + else: + # If any part of the path does not exist, return without setting content + return + + # Set the content of the node + current.content = content + + def set_content_by_headings(self, titles, contents): + """ + Set the contents of nodes identified by their paths. + + param titles: A list of title strings representing the paths. + param contents: A list of contents corresponding to the titles. + """ + for title, content in zip(titles, contents): + path, _ = title.split(' ') + self.set_content_by_id(path, content) + + def traverse_and_output(self, node=None, prefix='', level=0): + if node is None: + node = self.root + + output = [] + for key, child in node.subheading.items(): + title = f"{prefix}{key} {child.name}" if prefix else f"{key} {child.name}" + output.append((title, child.content, level + 1)) + output.extend(self.traverse_and_output(child, f"{prefix}{key}.", level + 1)) + + return output + + def get_chapter_obj_by_id(self, id: str) -> str: + """ + Get the chapter name by its hierarchical ID. + + param id: The hierarchical ID of the chapter to look up, formatted as "x.y.z...". + return: The name of the chapter if found, otherwise an empty string. + """ + # Split the hierarchical ID into its components + id_parts = id.split('.') + # Start from the root and navigate through the structure using the ID + current = self.root + for part in id_parts: + if part in current.subheading: + current = current.subheading[part] + else: + # If any part of the ID is not found, return an empty string + return '' + # Return the name of the chapter node + return current + + + +class WriteOutFile: + + @staticmethod + def write_markdown_file(topic: str, tasks: Any, output_path: Union[str, Path]): + pass + + @staticmethod + def write_word_file(topic: str, tasks: Any, output_path: Union[str, Path]): + """ + Writes tasks to a Word document. + + topic (str): The main topic of the document. + tasks (List[Tuple[str, str, int]]): A list of tuples containing the title, content, and heading level of each task. + output_path (Union[str, Path]): The file path where the document will be saved. + """ + def post_processes(title: str, context: str) -> str: + """Post-processes the context by removing the first line that match the title.""" + normalize_text = lambda x: re.sub(r'[^\u4e00-\u9fa5]+', '', x) + split_context = context.split('\n') if context else [] + prefix = split_context[0] if split_context else '' + if normalize_text(title) == normalize_text(prefix): + context = '\n'.join(split_context[1:]) + return context + + # Ensure the write_out_file is a Path object + output_path = Path(output_path) + document = Document() + document.add_heading(topic, level=0) + + # Process each task + for title, content, level in tasks: + # Add a heading for the task + document.add_heading(title, level=level) + # Post-process and add it to the document + content = post_processes(title, content) + document.add_paragraph(content) + + document.add_page_break() + # Save the document + try: + document.save(output_path) + except Exception as e: + print(f"An error occurred while saving the document: {e}") + + + + \ No newline at end of file diff --git a/metagpt/ext/ai_writer/write_planner.py b/metagpt/ext/ai_writer/write_planner.py new file mode 100644 index 000000000..dab7f1714 --- /dev/null +++ b/metagpt/ext/ai_writer/write_planner.py @@ -0,0 +1,197 @@ +from __future__ import annotations +import json +import re +from typing import Any +from metagpt.actions.di.write_plan import ( + precheck_update_plan_from_rsp, + update_plan_from_rsp, +) +from metagpt.actions.di.ask_review import ReviewConst +from metagpt.strategy.planner import Planner +from metagpt.logs import logger +from metagpt.schema import Message, Plan +from metagpt.actions import Action +from metagpt.utils.common import CodeParser +from metagpt.const import METAGPT_ROOT,DATA_PATH +from metagpt.ext.ai_writer.utils import colored_decorator,TitleHierarchy + +class DocumentPlan(Action): + PROMPT_TEMPLATE: str = """ + # Context: + {context} + # Task: + Based on the topic, write a outline or modify an existing outline of what you should do to achieve the goal. A outline consists of one to {max_tasks} chapters. + If you are modifying an existing chapter, carefully follow the instruction, don't make unnecessary changes. Give the whole chapters unless instructed to modify only one chapter of the outline. + If you encounter errors on the current chapter, revise and output the current single chapter only. + Output a list of jsons following the format: + ```json + [ + {{ + "chapter_id": str = "unique identifier for a chapter in outline, can be an ordinal", + "chapter_name": str = "current chapter title in the outline", + "subheadings":list[str] = "this chapter is divided into several smaller sections with a finer level of detail (e.g., 1.1, 1.1.1, 1.1.2, 2.1)." + + }}, + ... + ] + ``` + """ + + async def run(self, context: list[Message], max_tasks: int = 7, human_design_planner: bool = False) -> str: + + if human_design_planner: + with open(DATA_PATH/ 'ai_writer/outlines/human_design_planner.json', 'r', encoding='utf-8') as file: + rsp = json.dumps(json.loads(file.read()), ensure_ascii=False) + else: + prompt = self.PROMPT_TEMPLATE.format( + context="\n".join([str(ct) for ct in context]), max_tasks=max_tasks) + rsp = await self._aask(prompt) + rsp = CodeParser.parse_code(block=None, text=rsp) + return rsp + + + +class WritePlanner(Planner): + human_design_planner: bool = False + titlehierarchy:Any = None + async def update_plan(self, goal: str = "", context: str = "", max_tasks: int = 7, max_retries: int = 3): + if goal: + self.plan = Plan(goal=goal, context = context) + plan_confirmed = False + while not plan_confirmed: + context = self.get_useful_memories() + rsp = await DocumentPlan().run(context, max_tasks, self.human_design_planner) + self.working_memory.add(Message(content=rsp, role="assistant", cause_by = DocumentPlan)) + + rsp = self.post_process_chapter_id_or_name(rsp) + # precheck plan before asking reviews + is_plan_valid, error = precheck_update_plan_from_rsp(rsp, self.plan) + plan_valid = self.precheck_from_rsp(rsp) + if not (is_plan_valid and plan_valid) and max_retries > 0: + error_msg = f"The generated plan is not valid with error: {error}, try regenerating, remember to generate either the whole plan or the single changed task only" + logger.warning(error_msg) + self.working_memory.add(Message(content=error_msg, role="assistant", cause_by= DocumentPlan)) + max_retries -= 1 + continue + _, plan_confirmed = await self.ask_review(trigger=ReviewConst.TASK_REVIEW_TRIGGER) + update_plan_from_rsp(rsp=rsp, current_plan=self.plan) + self.titlehierarchy = self.process_response_and_build_hierarchy(rsp=rsp) + self.working_memory.clear() + + + def post_process_chapter_id_or_name(self,rsp): + """ + Post-process the response to replace chapter_id and chapter_name with task_id and instruction. + This method takes a response (rsp) and replaces any occurrences of "chapter_id" with "task_id" + and "chapter_name" with "instruction". + + This is useful when the response contains references + to chapter identifiers and names that need to be updated to match the task's attributes. + """ + # chapter_id save in current_task.task_id + # chapter_name save in current_task.instruction + rsp = rsp.replace("chapter_id","task_id").replace("chapter_name","instruction") + return rsp + + + def precheck_from_rsp(self, rsp): + """ + Perform a pre-check on the response data to ensure it meets the expected format. + """ + try: + rsp = json.loads(rsp) + return all([isinstance(task_config["subheadings"],list) for task_config in rsp]) + except Exception: + return False + + + + def process_response_and_build_hierarchy(self,rsp = '')-> TitleHierarchy: + """ + Post-process the response data to update the title hierarchy. + """ + titlehierarchy = TitleHierarchy() + rsp = json.loads(rsp) + for element in rsp: + task_id = re.sub(r'\.\d+', '', element['task_id']) + titlehierarchy.add_chapter(task_id, element['instruction']) + titlehierarchy.add_subheadings(element['subheadings']) + return titlehierarchy + + + + async def process_task_result(self,task_result): + """ + Process the task result and ask the user for confirmation. + This method processes the given task result and prompts the user for confirmation. + If the user confirms, it calls the `confirm_task` method with the current task,the task result, and the review. + """ + review, task_result_confirmed = await self.ask_user_instruction(finished = True) + if task_result_confirmed: + await self.confirm_task(self.current_task, task_result, review) + + + async def ask_for_review(self, prompt:str): + """ + Prompt the user for review input and handle the response. + This method interacts with the user to collect their review input. If the auto_run + flag is not set, it prompts the user with the given prompt and waits for their input. + If the input matches certain exit words, the program will exit. Otherwise, the input + is processed and returned. + Parameters: + prompt (str): The prompt to display to the user. Default is an empty string. + Returns: + tuple: A tuple containing the user's response and a boolean indicating whether the input was confirmed. + """ + if not self.auto_run: + rsp = input(prompt) + if rsp.lower() in ReviewConst.EXIT_WORDS: + exit() + confirmed = rsp.lower() in ReviewConst.CONTINUE_WORDS or ReviewConst.CONTINUE_WORDS[0] in rsp.lower() + if not confirmed: + self.working_memory.add(Message(content=rsp, role="user")) + else: + rsp = '' + + return rsp, confirmed + else: + return '' , True + + + @colored_decorator('\033[1;30;47m') + async def ask_user_context(self): + context = ( + f"Please add detailed writing context to enhance the text.\n" + f"No relevant content to elaborate on, just respond with 'yes'.\n" + ) + context, c_confirmed = await self.ask_for_review(context) + return context, c_confirmed + + @colored_decorator('\033[1;30;42m') + async def ask_user_instruction(self, finished = False): + if finished: + instruction = f"Does the chapter's paragraph meet your requirements? Reply 'yes' or 'no'.\n" + else: + instruction = ( + f"Please add instruction to enhance the text. \n" + f"no instruction, simply reply with yes." + f"If you wish to revert to the original, just respond with 'redo'.\n" + ) + # Calculate confirmed based on the context and instruction received + instruction, i_confirmed = await self.ask_for_review(instruction) + return instruction, i_confirmed + + + async def ask_review_template(self): + """ + Ask the user for context and instruction to enhance the text of a given task. + This method prompts the user to provide context and instruction to enhance the text + of the given task. + """ + context, c_confirmed = await self.ask_user_context() + instruction, i_confirmed = await self.ask_user_instruction() + confirmed = ((c_confirmed or not context) and i_confirmed) or (not context and not instruction) + return instruction, context, confirmed + + + diff --git a/metagpt/ext/ai_writer/write_refine.py b/metagpt/ext/ai_writer/write_refine.py new file mode 100644 index 000000000..637cc824b --- /dev/null +++ b/metagpt/ext/ai_writer/write_refine.py @@ -0,0 +1,151 @@ +from __future__ import annotations +from metagpt.actions import Action +from metagpt.schema import Message +BEGGING_PROMPT = """ +# user requirements + {user_requirement} +# the current chapter title + {chapter_name} +# the subheadings + {subheadings} +# Selective reference + {contexts} +# Task: + Based on the user's specified writing topic and considering the current chapter title, its subheadings, and selective reference, + assist in crafting an introductory preamble for this chapter. + The preamble should serve as a guiding statement that sets the stage for the content to follow, highlighting the central theme and providing a roadmap for the subheadings. + +# Constraint: + 1. Integrate insights from the subheadings to enrich the content. + 2. Keep the expansion to approximately 10 sentences for brevity. + 3. Ensure the response is clear and concise, avoiding unnecessary elaboration. + 4. The response should be provided in {language} to align with the user's requirements. +""" + +class WriteGuide(Action): + async def run( self, + user_requirement: str, + chapter_name : str, + subheadings : str, + contexts : str, + language : str = 'chinese' + ) -> str: + structual_prompt = BEGGING_PROMPT.format( + user_requirement=user_requirement, + chapter_name = chapter_name, + subheadings = subheadings, + contexts = contexts, + language = language + ) + + context = self.llm.format_msg([Message(content=structual_prompt, role="user")]) + rsp = await self.llm.aask(context) + return rsp + +REFINE = """ +# user's guidance: + {user_requirement} + +# original answer: + {respones} + +# Additional Context: + {contexts} + +# task: + Enhance the clarity and effectiveness of the original answer to the query “{original_query}” by incorporating the user's guidance and the additional context provided. + +# Constraints: +1. If the additional context proves irrelevant or lacks substance, default to providing the initial response without modification. +2. Ensure that the refined answer is coherently integrated and maintains a high level of logical flow and readability. + +# Instructions: +- The response should be provided in {language} to align with the user's requirements. +- Carefully review the user's guidance to understand the desired improvements. +- Assess the additional context for its relevance and potential to enhance the original answer. +- Modify the original answer by adding, removing, or rephrasing content as necessary to align with the user's guidance and the new context. +- Proofread the refined answer to ensure it is clear, concise, and effectively addresses the original query. +- Maintain the integrity of the original response while incorporating the enhancements. +""" + +class Refine(Action): + async def run( self, + original_query: str, + respones : str, + contexts : str, + user_requirement : str = '', + language : str = 'chinese', + **kwargs) -> str: + original_query = f"{original_query}" + structual_prompt = REFINE.format( + original_query = original_query, + respones = respones, + contexts = contexts, + user_requirement = user_requirement, + language = language + ) + context = self.llm.format_msg([Message(content=structual_prompt, role="user")]) + rsp = await self.llm.aask(context, **kwargs) + return rsp + + +SUBSECTION_PROMPT = """ +# user requirements: + To generate a comprehensive paragraph that elaborates on the given subsection heading within the provided context. + +# the subsection heading: Begin by identifying the specific subsection heading provided. + {subsection} + +# Contextual Preamble: Consider the introductory context that sets the stage for the subsection. + {contexts} + +# Writing Guidelines: + Step 1: Reflect on the subsection heading and how it relates to the preamble. + Step 2: Develop a coherent paragraph that builds upon the preamble and directly addresses the subsection heading. + Step 3: Ensure the paragraph is enriched with relevant details, examples, or explanations that enhance understanding. + Step 4: Review the paragraph for clarity, coherence, and adherence to the subsection's focus. + +# Instructions: + 1. Align with Subsection Heading: + Ensure that your content directly corresponds to the topic outlined by the subsection heading. This alignment is crucial for maintaining coherence and relevance throughout the document. + 2. Follow Specified Headings: + Strictly adhere to the headings provided. Each section should be crafted to specifically address and expand upon the ideas presented in these headings. + 3. Focus on Content Depth: + Concentrate on developing the substance of your writing around the core theme indicated by the subsection heading. Where applicable, enrich your discussion with relevant references to support your arguments and enhance credibility. +""" + +class WriteSubsection(Action): + async def run( self, subsection : str, contexts : str, **kwargs) -> str: + structual_prompt = SUBSECTION_PROMPT.format( subsection = subsection, + contexts = contexts + ) + context = self.llm.format_msg([Message(content=structual_prompt, role="user")]) + rsp = await self.llm.aask(context, **kwargs) + return rsp + + +CLC_PROMPT = """ +# Title: + {title} +# Provided contexts: + ``` + {contexts} + ``` +# task: + Extract key information pertinent to the title from the provided context to ensure that the data extracted accurately reflects the subject matter. + If the context contains no relevant information, return an empty result. Focus on filtering out any irrelevant or non-useful data to maintain the accuracy and relevance of the extracted content. + +# Constraints: + 1. Keep the expansion to approximately 5 sentences for brevity. + 2. The response should be provided in {language}. +""" +class Clean(Action): + async def run(self,title: str,contexts: str, language:str = 'chinese') -> str: + structual_prompt = CLC_PROMPT.format( + title = title, + contexts = contexts, + language = language + ) + context = self.llm.format_msg([Message(content=structual_prompt, role="user")]) + rsp = await self.llm.aask(context) + return rsp \ No newline at end of file diff --git a/metagpt/ext/ai_writer/writer.py b/metagpt/ext/ai_writer/writer.py new file mode 100644 index 000000000..5881dba60 --- /dev/null +++ b/metagpt/ext/ai_writer/writer.py @@ -0,0 +1,277 @@ +from __future__ import annotations +from typing import Literal,List, Optional +import re +from pydantic import Field, model_validator +from metagpt.logs import logger +from metagpt.roles import Role +from metagpt.tools.search_engine import SearchEngine +from metagpt.schema import Message, Task, TaskResult +from llama_index.core import Document +from metagpt.ext.ai_writer.write_refine import WriteGuide,Refine, WriteSubsection, Clean +from metagpt.actions import SearchAndSummarize, UserRequirement +from metagpt.ext.ai_writer.write_planner import WritePlanner +from metagpt.ext.ai_writer.utils import colored_decorator, print_time + + +class DocumentWriter(Role): + name: str = "wangs" + profile: str = "document_writer" + goal: str = "write a long document" + auto_run: bool = True + use_plan: bool = True + react_mode: Literal["plan_and_act", "react"] = 'plan_and_act' #"by_order" + human_design_planner: bool = False + max_react_loop : int = 5 + planner: WritePlanner = Field(default_factory= WritePlanner) + use_store:bool = True + store: Optional[object] = Field(default=None, exclude=True) + store_mode: Literal["search", "retrieve", "retrieve_clean"] = "retrieve_clean" + + + @model_validator(mode="after") + def set_plan(self): + self._set_react_mode(react_mode=self.react_mode, max_react_loop=self.max_react_loop, auto_run=self.auto_run) + self.planner = WritePlanner(auto_run=self.auto_run, + human_design_planner = self.human_design_planner + ) + + self.use_plan = (self.react_mode == "plan_and_act" ) + self.set_actions([WriteGuide, Refine, WriteSubsection]) + self._set_state(0) + return self + + @model_validator(mode="after") + def validate_stroe(self): + if self.store: + search_engine = SearchEngine.from_search_func(search_func=self.store.asearch, proxy=self.config.proxy) + action = SearchAndSummarize(search_engine=search_engine, context=self.context) + else: + action = SearchAndSummarize(context=self.context) + self.actions.append(action) + return self + + @property + def working_memory(self): + return self.rc.working_memory + + @print_time + async def run(self, requirement) -> Message | None: + return await super().run(requirement) + + async def _act_on_task(self, current_task: Task) -> TaskResult: + """ + Process a given task by either writing an initial draft or refining a draft based on user feedback . + This method differentiates the action to be taken based on the completion status of the task. + If the task is marked as finished, it will refine the draft; otherwise, it will write a new draft. + """ + id = re.sub(r'\.\d+', '', current_task.task_id) + chapter_name = current_task.instruction + first_subheadings = self.planner.titlehierarchy.get_subheadings_by_prefix(id) + + if not current_task.is_finished: + await self.write_draft(id, chapter_name, first_subheadings) + else: + await self.refine_draft() + + current_task.is_finished = True + task_result = TaskResult( + result = '', + is_success = True + ) + return task_result + + + async def refine_draft(self): + """ + Refine the draft based on user feedback associated with the task. + """ + pass + + + async def write_draft(self, id:str, chapter_name:str, subheadings:List[str]): + """ + this function is responsible for creating an initial draft of a chapter based on provided information. + + Parameters: + - `id` (str): Unique identifier for the chapter. + - `chapter_name` (str): The title of the chapter to be written. + - `subheadings` (List[str]): A list of subheading titles for the chapter's content. + """ + if subheadings: + self._set_state(0) + guidelines, cause_by = await self.generate_guide(chapter_name,subheadings) + self.working_memory.add(Message(content= guidelines, role="assistant", cause_by=cause_by)) + + self._set_state(1) + guidelines, cause_by = await self.refine_guide(chapter_name,guidelines) + self.working_memory.add(Message(content= guidelines, role="assistant", cause_by=cause_by)) + else: + subheadings, guidelines = [f"{id} {chapter_name}"], '' + + # No subheadings (subdirectories), no need to write guides, directly generated content. + self._set_state(2) + subheadings, gen_subsection, cause_by = await self.generate_parallel_subsection(subheadings,guidelines) + self.working_memory.add_batch([Message(content = x, role="assistant", cause_by=cause_by) for x in gen_subsection]) + + # Set guidelines and subsections into the heading hierarchy, an object used for document structure management. + self.planner.titlehierarchy.set_content_by_id(id, guidelines) + self.planner.titlehierarchy.set_content_by_headings(subheadings,gen_subsection) + + # Add the generated document to the node for easier retrieval + await self.add_nodes_doc() + # Log completion of actions + logger.info(f"All actions for chapter {chapter_name} have been finished.") + + + async def generate_guide(self, chapter_name:str,subheadings:List[str]): + """ + Generate the beginning guidelines of a chapter based on the current task and user requirements. + This method initializes the generation process by checking for the existence of necessary + attributes, retrieving user requirements, and context, then executing a task to generate + the beginning of a document. + + Returns: + tuple: A tuple containing the generated beginning of the chapter and the todo object used. + """ + if not hasattr(self, 'rc') or not hasattr(self.rc, 'todo'): + raise AttributeError("Expected 'rc' and 'rc.todo' to be initialized") + + todo = self.rc.todo + user_requirement = self.get_memories()[0].content + logger.info("Starting to retrieve relevant information from documents") + contexts = await self.doc_retrieve(chapter_name) + logger.info("Starting to write the opening paragraph") + guidelines = await todo.run( + user_requirement=user_requirement, + chapter_name = chapter_name, + subheadings = ','.join([section for section in subheadings]), + contexts = contexts + ) + return guidelines, todo + + + async def generate_parallel_subsection(self,subheadings,guidelines): + """ + Generate parallel subsections of a chapter concurrently based on the current task. + This method generates multiple subsections of a chapter in parallel, using the task's + dependent task IDs and the 'todo' object. + It retrievescontexts for each subsection, and then runs the 'todo' task concurrently for each subsection. + + + Returns: + tuple: A tuple containing a list of results from the parallel tasks and the 'todo' object used. + """ + + if not hasattr(self, 'rc') or not hasattr(self.rc, 'todo'): + raise AttributeError("Expected 'rc' and 'rc.todo' to be initialized") + if not isinstance(subheadings, list): + subheadings = [subheadings] + + todo = self.rc.todo + logger.info(f"ready to {todo.name}") + + context, first_headings = [], [] + for section in subheadings: + _id, name = section.split(' ') + subtitle = self.planner.titlehierarchy.get_subheadings_by_prefix(_id) + if subtitle: + await self.write_draft(_id, name,subtitle) + else: + contexts = await self.doc_retrieve(name) + gen_subsection= await todo.run( + subsection = section, + contexts = f'{guidelines}\n\n# Reference: \n```{contexts}```' + ) + first_headings.append(section) + context.append(gen_subsection) + # context = await asyncio.gather(*context) + return first_headings, context , todo + + + async def refine_guide(self,chapter_name,guidelines): + """ + This method refines the beginning of a chapter by repeatedly asking for user review + and instructions until the user confirms that the result is satisfactory. + + Returns: + tuple: A tuple containing the refined beginning of the document and the 'todo' object used. + """ + instruction, context, confirmed = await self.planner.ask_review_template() + if not hasattr(self, 'rc') or not hasattr(self.rc, 'todo'): + raise AttributeError("Expected 'rc' and 'rc.todo' to be initialized") + todo = self.rc.todo + refine_sets, cur_result = [], guidelines + while not confirmed: + if instruction == 'redo': + refine_sets, cur_result = [], guidelines + logger.info(f"Redo finished, revert to the original response.") + instruction, context, confirmed = await self.planner.ask_review_template() + continue + refine_sets.append(cur_result) + pre_result = '\n\n'.join(refine_sets[-3:]) # Retain up to 2 rounds of results + cur_result = await Refine().run( + user_requirement = instruction, + original_query = chapter_name, + respones = pre_result, + contexts = context + ) + + instruction, confirmed = await self.planner.ask_user_instruction() + return cur_result, todo + + @colored_decorator("\033[1;46m") + async def doc_retrieve(self,title): + """ + Retrieve relevant information from documents based on the given title. + + This method retrieves contexts from documents using an engine, which could be a retrieval or a RAG (Retrieval Augmented Generation) model. + The retrieval mode is determined by the `rag_mode` attribute of the class instance. + The retrieved contexts are then optionally cleaned if the mode is set to 'retrieve_gen'. + Parameters: + title (str): The title or query used to retrieve relevant information from documents. + Returns: + contexts: A string containing the retrieved contexts, separated by section dividers. + """ + if not self.use_store: + return '' + + contexts = '' + if self.store_mode == 'search': + todo = self.actions[-1] + prompt = f'Please help me write the content of chapter "{title}"' if self.store else title + message = [Message(content= prompt, role="user", cause_by = UserRequirement)] + contexts = await todo.run(self.rc.history + message) + + if self.store_mode.startswith('retrieve'): + # Retrieve contexts using the store's aretrieve method + contexts = await self.store.aretrieve(title) + contexts = '\n\n'.join([x.text for x in contexts]) + if self.store_mode == 'retrieve_clean': + contexts = await Clean().run(title = title, contexts = contexts) + return contexts + + + async def add_nodes_doc(self): + """ + Add nodes to the document retriever engine based on the assistant's messages from memory. + This method extracts text chunks from the working memory where the role is "assistant", + creates Document objects with these chunks, and then adds these documents as nodes to the + retriever engine. + """ + if not self.use_store or not self.store: + return + text_chunks = [message.content for message in self.working_memory.get() if message.role == "assistant"] + doc_chunks = [] + for i, text in enumerate(set(text_chunks)): + doc = Document(text=text, id_=f"doc_id_{i}") + doc_chunks.append(doc) + # Add the list of Document objects as nodes to the retriever engine + self.store.retriever.add_nodes(doc_chunks) + self.working_memory.clear() + return + + + + + +