diff --git a/bot.go b/bot.go index fc24e8e..d817426 100644 --- a/bot.go +++ b/bot.go @@ -17,7 +17,7 @@ type Bot struct { UUIDCallback func(uuid string) // 获取UUID的回调函数 SyncCheckCallback func(resp SyncCheckResponse) // 心跳回调 MessageHandler MessageHandler // 获取消息成功的handle - MessageErrorHandler func(err error) bool // 获取消息发生错误的handle, 返回true则尝试继续监听 + MessageErrorHandler MessageErrorHandler // 获取消息发生错误的handle, 返回err == nil 则尝试继续监听 Serializer Serializer // 序列化器, 默认为json Caller *Caller Storage *Session @@ -102,7 +102,7 @@ func (b *Bot) Logout() error { if err := b.Caller.Logout(info); err != nil { return err } - b.Exit() + b.ExitWith(ErrUserLogout) return nil } return errors.New("user not login") @@ -171,18 +171,15 @@ func (b *Bot) webInit() error { go func() { if b.MessageErrorHandler == nil { - b.MessageErrorHandler = defaultSyncCheckErrHandler(b) + b.MessageErrorHandler = defaultMessageErrorHandler } for { - err = b.syncCheck() - if err == nil { - continue - } - // 判断是否继续, 如果不继续则退出 - if goon := b.MessageErrorHandler(err); !goon { - b.err = err - b.Exit() - break + if err = b.syncCheck(); err != nil { + // 判断是否继续, 如果不继续则退出 + if err = b.MessageErrorHandler(err); err != nil { + b.ExitWith(err) + return + } } } }() @@ -292,6 +289,12 @@ func (b *Bot) Exit() { } } +// ExitWith 主动退出并且设置退出原因, 可以通过 `CrashReason` 获取退出原因 +func (b *Bot) ExitWith(err error) { + b.err = err + b.Exit() +} + // CrashReason 获取当前Bot崩溃的原因 func (b *Bot) CrashReason() error { return b.err @@ -400,21 +403,6 @@ func DefaultBot(prepares ...BotPreparer) *Bot { return bot } -// defaultSyncCheckErrHandler 默认的SyncCheck错误处理函数 -func defaultSyncCheckErrHandler(bot *Bot) func(error) bool { - return func(err error) bool { - var ret Ret - if errors.As(err, &ret) { - switch ret { - case failedLoginCheck, cookieInvalid, failedLoginWarn: - _ = bot.Logout() - return false - } - } - return true - } -} - // GetQrcodeUrl 通过uuid获取登录二维码的url func GetQrcodeUrl(uuid string) string { return qrcode + uuid diff --git a/errors.go b/errors.go index 036516a..b41e01f 100644 --- a/errors.go +++ b/errors.go @@ -35,6 +35,9 @@ var ( // ErrWebWxDataTicketNotFound define webwx_data_ticket not found error ErrWebWxDataTicketNotFound = errors.New("webwx_data_ticket not found") + + // ErrUserLogout define user logout error + ErrUserLogout = errors.New("user logout") ) // Error impl error interface diff --git a/message_handle.go b/message_handle.go index 95d049b..fe6d4b7 100644 --- a/message_handle.go +++ b/message_handle.go @@ -1,6 +1,9 @@ package openwechat -import "strings" +import ( + "errors" + "strings" +) // MessageHandler 消息处理函数 type MessageHandler func(msg *Message) @@ -283,3 +286,20 @@ func SenderNickNameContainsMatchFunc(nickname string) MatchFunc { func SenderRemakeNameContainsFunc(remakeName string) MatchFunc { return SenderMatchFunc(func(user *User) bool { return strings.Contains(user.RemarkName, remakeName) }) } + +// MessageErrorHandler 获取消息时发生了错误的处理函数 +// 参数err为获取消息时发生的错误,返回值为处理后的错误 +// 如果返回nil,则表示忽略该错误,否则将继续传递该错误 +type MessageErrorHandler func(err error) error + +// defaultMessageErrorHandler 默认的SyncCheck错误处理函数 +func defaultMessageErrorHandler(err error) error { + var ret Ret + if errors.As(err, &ret) { + switch ret { + case failedLoginCheck, cookieInvalid, failedLoginWarn: + return ret + } + } + return nil +}