Skip to content

Commit

Permalink
refactor: message content use struct
Browse files Browse the repository at this point in the history
  • Loading branch information
liushuangls committed Mar 14, 2024
1 parent 492fee4 commit a2a97c5
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 43 deletions.
28 changes: 11 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func main() {
resp, err := client.CreateMessages(context.Background(), anthropic.MessagesRequest{
Model: anthropic.ModelClaudeInstant1Dot2,
Messages: []anthropic.Message{
{Role: anthropic.RoleUser, Content: "What is your name?"},
anthropic.NewUserTextMessage("What is your name?"),
},
MaxTokens: 1000,
})
Expand Down Expand Up @@ -75,7 +75,7 @@ func main() {
MessagesRequest: anthropic.MessagesRequest{
Model: anthropic.ModelClaudeInstant1Dot2,
Messages: []anthropic.Message{
{Role: anthropic.RoleUser, Content: "What is your name?"},
anthropic.NewUserTextMessage("What is your name?"),
},
MaxTokens: 1000,
},
Expand Down Expand Up @@ -121,25 +121,19 @@ func main() {
if err != nil {
panic(err)
}

resp, err := client.CreateMessages(context.Background(), anthropic.MessagesRequest{
Model: anthropic.ModelClaude3Opus20240229, // only claude 3 model can use vision
Model: anthropic.ModelClaude3Opus20240229,
Messages: []anthropic.Message{
{
Role: anthropic.RoleUser,
Content: []any{
anthropic.MessageImageContent{
Type: "image",
Source: anthropic.MessageImageContentSource{
Type: "base64",
MediaType: imageMediaType,
Data: imageData,
},
},
anthropic.MessageTextContent{
Type: "text",
Text: "Describe this image.",
},
Content: []anthropic.MessageContent{
anthropic.NewImageMessageContent(anthropic.MessageContentImageSource{
Type: "base64",
MediaType: imageMediaType,
Data: imageData,
}),
anthropic.NewTextMessageContent("Describe this image."),
},
},
},
Expand Down
64 changes: 55 additions & 9 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,67 @@ func (m *MessagesRequest) SetTopK(k int) {
}

type Message struct {
Role string `json:"role"`
Content any `json:"content"` // Content can be string, MessageTextContent or MessageImageContent or slice
Role string `json:"role"`
Content []MessageContent `json:"content"`
}

type MessageTextContent struct {
Type string `json:"type"`
Text string `json:"text"`
func NewUserTextMessage(text string) Message {
return Message{
Role: "user",
Content: []MessageContent{NewTextMessageContent(text)},
}
}

func NewAssistantTextMessage(text string) Message {
return Message{
Role: "assistant",
Content: []MessageContent{NewTextMessageContent(text)},
}
}

func (m Message) GetFirstContent() MessageContent {
if len(m.Content) == 0 {
return MessageContent{}
}
return m.Content[0]
}

type MessageContent struct {
Type string `json:"type"`
Text *string `json:"text,omitempty"`
Source *MessageContentImageSource `json:"source,omitempty"`
}

type MessageImageContent struct {
Type string `json:"type"`
Source MessageImageContentSource `json:"source"`
func NewTextMessageContent(text string) MessageContent {
return MessageContent{
Type: "text",
Text: &text,
}
}

func NewImageMessageContent(source MessageContentImageSource) MessageContent {
return MessageContent{
Type: "image",
Source: &source,
}
}

func (m MessageContent) IsTextContent() bool {
return m.Type == "text"
}

func (m MessageContent) IsImageContent() bool {
return m.Type == "image"
}

func (m MessageContent) GetText() string {
if m.IsTextContent() && m.Text != nil {
return *m.Text
}
return ""
}

type MessageImageContentSource struct {
type MessageContentImageSource struct {
Type string `json:"type"`
MediaType string `json:"media_type"`
Data any `json:"data"`
Expand Down
4 changes: 2 additions & 2 deletions message_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestMessagesStream(t *testing.T) {
MessagesRequest: anthropic.MessagesRequest{
Model: anthropic.ModelClaudeInstant1Dot2,
Messages: []anthropic.Message{
{Role: anthropic.RoleUser, Content: "What is your name?"},
anthropic.NewUserTextMessage("What is your name?"),
},
MaxTokens: 1000,
},
Expand Down Expand Up @@ -81,7 +81,7 @@ func TestMessagesStreamError(t *testing.T) {
MessagesRequest: anthropic.MessagesRequest{
Model: anthropic.ModelClaudeInstant1Dot2,
Messages: []anthropic.Message{
{Role: anthropic.RoleUser, Content: "What is your name?"},
anthropic.NewUserTextMessage("What is your name?"),
},
MaxTokens: 1000,
},
Expand Down
24 changes: 9 additions & 15 deletions message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestMessages(t *testing.T) {
resp, err := client.CreateMessages(context.Background(), anthropic.MessagesRequest{
Model: anthropic.ModelClaudeInstant1Dot2,
Messages: []anthropic.Message{
{Role: anthropic.RoleUser, Content: "What is your name?"},
anthropic.NewUserTextMessage("What is your name?"),
},
MaxTokens: 1000,
})
Expand All @@ -65,7 +65,7 @@ func TestMessagesTokenError(t *testing.T) {
_, err := client.CreateMessages(context.Background(), anthropic.MessagesRequest{
Model: anthropic.ModelClaudeInstant1Dot2,
Messages: []anthropic.Message{
{Role: anthropic.RoleUser, Content: "What is your name?"},
anthropic.NewUserTextMessage("What is your name?"),
},
MaxTokens: 1000,
})
Expand Down Expand Up @@ -109,19 +109,13 @@ func TestMessagesVision(t *testing.T) {
Messages: []anthropic.Message{
{
Role: anthropic.RoleUser,
Content: []any{
anthropic.MessageImageContent{
Type: "image",
Source: anthropic.MessageImageContentSource{
Type: "base64",
MediaType: imageMediaType,
Data: imageData,
},
},
anthropic.MessageTextContent{
Type: "text",
Text: "Describe this image.",
},
Content: []anthropic.MessageContent{
anthropic.NewImageMessageContent(anthropic.MessageContentImageSource{
Type: "base64",
MediaType: imageMediaType,
Data: imageData,
}),
anthropic.NewTextMessageContent("Describe this image."),
},
},
},
Expand Down

0 comments on commit a2a97c5

Please sign in to comment.