Skip to content

Commit

Permalink
chore: shorten prompt and adjust the prompt to reduce instruction con…
Browse files Browse the repository at this point in the history
…fusion (#54)

* chore: shorten prompt and adjust the prompt to reduce instruction confusion

* chore: improve recap summarization prompt
Co-authored-by: OverflowCat <[email protected]>

* chore: improve input context
Co-authored-by: OverflowCat <[email protected]>

* chore: improve LLM-friendly context
Co-authored-by: OverflowCat <[email protected]>

---------

Co-authored-by: OverflowCat <[email protected]>
  • Loading branch information
nekomeowww and OverflowCat authored May 6, 2023
1 parent 53ddb31 commit bd444e6
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 35 deletions.
39 changes: 26 additions & 13 deletions internal/models/chathistories/chat_histories.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/nekomeowww/insights-bot/pkg/bots/tgbot"
"github.com/nekomeowww/insights-bot/pkg/logger"
"github.com/nekomeowww/insights-bot/pkg/openai"
"github.com/nekomeowww/insights-bot/pkg/utils"
)

type NewModelParams struct {
Expand Down Expand Up @@ -93,9 +94,9 @@ func (m *Model) SaveOneTelegramChatHistory(message *tgbotapi.Message) error {
return nil
}
if message.ForwardFrom != nil {
telegramChatHistoryCreate.SetText("转发了来自" + tgbot.FullNameFromFirstAndLastName(message.ForwardFrom.FirstName, message.ForwardFrom.LastName) + "的消息:" + text)
telegramChatHistoryCreate.SetText(fmt.Sprintf("[forwarded from %s]: %s", tgbot.FullNameFromFirstAndLastName(message.ForwardFrom.FirstName, message.ForwardFrom.LastName), text))
} else if message.ForwardFromChat != nil {
telegramChatHistoryCreate.SetText("转发了来自" + message.ForwardFromChat.Title + "的消息:" + text)
telegramChatHistoryCreate.SetText(fmt.Sprintf("[forwarded from %s]: %s", message.ForwardFromChat.Title, text))
} else {
telegramChatHistoryCreate.SetText(text)
}
Expand Down Expand Up @@ -164,6 +165,10 @@ func formatFullNameAndUsername(fullName, username string) string {
return strings.ReplaceAll(fullName, "#", "")
}

// formatChatHistoryTextContent returns text enclosed in a pair of triple
// double-quote fences, i.e. `"""` + text + `"""`.
//
// The text is inserted verbatim: nothing inside it is escaped or trimmed, so
// an empty input yields six consecutive double quotes.
func formatChatHistoryTextContent(text string) string {
	const fence = `"""`

	// All operands are strings, so fmt.Sprint joins them without adding spaces.
	return fmt.Sprint(fence, text, fence)
}

type RecapOutputTemplateInputs struct {
ChatID string
Recaps []*openai.ChatHistorySummarizationOutputs
Expand All @@ -186,10 +191,10 @@ var RecapOutputTemplate = lo.Must(template.
"add": func(a, b int) int { return a + b },
"escape": tgbot.EscapeHTMLSymbols,
}).
Parse(`{{ $chatID := .ChatID }}{{ $recapLen := len .Recaps }}{{ range $i, $r := .Recaps }}{{ if $r.SinceMsgID }}## <a href="https://t.me/c/{{ $chatID }}/{{ $r.SinceMsgID }}">{{ escape $r.TopicName }}</a>{{ else }}## {{ escape $r.TopicName }}{{ end }}
Parse(`{{ $chatID := .ChatID }}{{ $recapLen := len .Recaps }}{{ range $i, $r := .Recaps }}{{ if $r.SinceID }}## <a href="https://t.me/c/{{ $chatID }}/{{ $r.SinceID }}">{{ escape $r.TopicName }}</a>{{ else }}## {{ escape $r.TopicName }}{{ end }}
参与人:{{ join $r.ParticipantsNamesWithoutUsername "," }}
讨论:{{ range $di, $d := $r.Discussion }}
- {{ escape $d.Point }}{{ if len $d.CriticalMessageIDs }} {{ range $cIndex, $c := $d.CriticalMessageIDs }}<a href="https://t.me/c/{{ $chatID }}/{{ $c }}">[{{ add $cIndex 1 }}]</a>{{ if not (eq $cIndex (sub (len $d.CriticalMessageIDs) 1)) }} {{ end }}{{ end }}{{ end }}{{ end }}{{ if $r.Conclusion }}
- {{ escape $d.Point }}{{ if len $d.CriticalIDs }} {{ range $cIndex, $c := $d.CriticalIDs }}<a href="https://t.me/c/{{ $chatID }}/{{ $c }}">[{{ add $cIndex 1 }}]</a>{{ if not (eq $cIndex (sub (len $d.CriticalIDs) 1)) }} {{ end }}{{ end }}{{ end }}{{ end }}{{ if $r.Conclusion }}
结论:{{ escape $r.Conclusion }}{{ end }}{{ if eq $i (sub $recapLen 1) }}{{ else }}
{{ end }}{{ end }}`))
Expand Down Expand Up @@ -222,6 +227,8 @@ func (m *Model) summarizeChatHistoriesSlice(s string) ([]*openai.ChatHistorySumm
return nil, err
}

m.logger.Infof("✅ unmarshaled chat history summarization output: %s", utils.SprintJSON(outputs))

return outputs, nil
}

Expand All @@ -231,19 +238,23 @@ func (m *Model) SummarizeChatHistories(chatID int64, histories []*ent.ChatHistor
for _, message := range histories {
if message.RepliedToMessageID == 0 {
historiesLLMFriendly = append(historiesLLMFriendly, fmt.Sprintf(
"msgId:%d: %s 发送:%s",
"msgId:%d: %s sent: %s",
message.MessageID,
formatFullNameAndUsername(message.FullName, message.Username),
message.Text,
formatChatHistoryTextContent(message.Text),
))
} else {
repliedToPartialContextMessage := fmt.Sprintf("%s 发送的 msgId:%d 的消息", formatFullNameAndUsername(message.RepliedToFullName, message.RepliedToUsername), message.RepliedToMessageID)
repliedToPartialContextMessage := fmt.Sprintf(
"%s sent msgId:%d",
formatFullNameAndUsername(message.RepliedToFullName, message.RepliedToUsername),
message.RepliedToMessageID,
)
historiesLLMFriendly = append(historiesLLMFriendly, fmt.Sprintf(
"msgId:%d: %s 回复 %s:%s",
"msgId:%d: %s replying to [%s]: %s",
message.MessageID,
formatFullNameAndUsername(message.FullName, message.Username),
repliedToPartialContextMessage,
message.Text,
formatChatHistoryTextContent(message.Text),
))
}
}
Expand Down Expand Up @@ -274,15 +285,15 @@ func (m *Model) SummarizeChatHistories(chatID int64, histories []*ent.ChatHistor

for _, o := range outputs {
for _, d := range o.Discussion {
d.CriticalMessageIDs = lo.UniqBy(d.CriticalMessageIDs, func(item int64) int64 {
d.CriticalIDs = lo.UniqBy(d.CriticalIDs, func(item int64) int64 {
return item
})
d.CriticalMessageIDs = lo.Filter(d.CriticalMessageIDs, func(item int64, _ int) bool {
d.CriticalIDs = lo.Filter(d.CriticalIDs, func(item int64, _ int) bool {
return item != 0
})

if len(d.CriticalMessageIDs) > 5 {
d.CriticalMessageIDs = d.CriticalMessageIDs[:5]
if len(d.CriticalIDs) > 5 {
d.CriticalIDs = d.CriticalIDs[:5]
}
}
}
Expand All @@ -300,5 +311,7 @@ func (m *Model) SummarizeChatHistories(chatID int64, histories []*ent.ChatHistor
return "", err
}

m.logger.Infof("✅ summarized chat histories: %s", sb.String())

return sb.String(), nil
}
16 changes: 8 additions & 8 deletions internal/models/chathistories/chat_histories_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,12 @@ func TestRecapOutputTemplateExecute(t *testing.T) {
Recaps: []*openai.ChatHistorySummarizationOutputs{
{
TopicName: "Topic 1",
SinceMsgID: 1,
SinceID: 1,
ParticipantsNamesWithoutUsername: []string{"User 1", "User 2"},
Discussion: []*openai.ChatHistorySummarizationOutputsDiscussion{
{
Point: "Point 1",
CriticalMessageIDs: []int64{1, 2},
Point: "Point 1",
CriticalIDs: []int64{1, 2},
},
{
Point: "Point 2",
Expand All @@ -173,19 +173,19 @@ func TestRecapOutputTemplateExecute(t *testing.T) {
Point: "Point 1",
},
{
Point: "Point 2",
CriticalMessageIDs: []int64{1, 2},
Point: "Point 2",
CriticalIDs: []int64{1, 2},
},
},
},
{
TopicName: "Topic 1",
SinceMsgID: 2,
SinceID: 2,
ParticipantsNamesWithoutUsername: []string{"User 1", "User 2"},
Discussion: []*openai.ChatHistorySummarizationOutputsDiscussion{
{
Point: "Point 1",
CriticalMessageIDs: []int64{1, 2},
Point: "Point 1",
CriticalIDs: []int64{1, 2},
},
{
Point: "Point 2",
Expand Down
21 changes: 7 additions & 14 deletions pkg/openai/prompts.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,22 @@ type ChatHistorySummarizationPromptInputs struct {
}

type ChatHistorySummarizationOutputsDiscussion struct {
Point string `json:"point"`
CriticalMessageIDs []int64 `json:"criticalMsgIds"`
Point string `json:"point"`
CriticalIDs []int64 `json:"criticalIds"`
}

type ChatHistorySummarizationOutputs struct {
TopicName string `json:"topicName"`
SinceMsgID int64 `json:"sinceMsgId"`
SinceID int64 `json:"sinceId"`
ParticipantsNamesWithoutUsername []string `json:"participantsNamesWithoutUsername"`
Discussion []*ChatHistorySummarizationOutputsDiscussion `json:"discussion"`
Conclusion string `json:"conclusion"`
}

var ChatHistorySummarizationPrompt = lo.Must(template.New(uuid.New().String()).Parse("" +
`你是我的聊天记录总结和回顾助理。我将为你提供一份不完整的、在过去一个小时中的、包含了人物名称、人物用户名、消息发送时间、消息内容等信息的聊天记录,这些聊天记录条目每条一行,我需要你总结这些聊天记录,并在有结论的时候提供结论总结。
请你使用下面的 JSON 格式进行输出,不需要提供额外的解释和说明。为了方便理解,下面的 JSON 中 sinceMsgId 代表了话题开始的消息 ID,而 criticalMsgIds 代表了讨论过程中出现的「关键消息」,因此你不需要罗列出所有的相关 criticalMsgIds。
输出时所使用的 JSON 格式:"""
[
{ "topicName": "..", "sinceMsgId": 123456789, "participantsNamesWithoutUsername": [ "..", ".." ], "discussion": [ { "point": "..", "criticalMsgIds": [ 123456789, 123456789 ] }, { "point": "..", "criticalMsgIds": [ 123456789, 123456789 ] } ], "conclusion": ".." },
{ "topicName": "..", "sinceMsgId": 123456789, "participantsNamesWithoutUsername": [ "..", ".." ], "discussion": [ { "point": "..", "criticalMsgIds": [ 123456789, 123456789 ] }, { "point": "..", "criticalMsgIds": [ 123456789, 123456789 ] } ], "conclusion": ".." }
]
`聊天记录:"""
{{ .ChatHistory }}
"""
聊天记录:"""
{{ .ChatHistory }}
"""`))
你是我的聊天记录总结和回顾助理。以上是一份聊天记录,每条消息以 msgId 开头,请总结这些聊天记录为1~5个话题,每个话题需包含以下字段 sinceId(话题开始的 msgId)、criticalIds(讨论过程中的关键 msgId,最多5条)和 conclusion(结论,若无明确结论则该字段为空)。请使用以下 JSON 格式输出,无需额外解释说明:"""
[{"topicName":"..","sinceId":123456789,"participantsNamesWithoutUsername":[".."],"discussion":[{"point":"..","criticalIds":[123456789]}],"conclusion":".."}]"""`))

0 comments on commit bd444e6

Please sign in to comment.