OpenAI wrapper #1371

Open · wants to merge 11 commits into base: development
189 changes: 189 additions & 0 deletions pkg/gofr/datasource/openai/chatcompletion.go
@@ -0,0 +1,189 @@
package openai

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"
)

const CompletionsEndpoint = "/v1/chat/completions"

type CreateCompletionsRequest struct {
	Messages            []Message      `json:"messages,omitempty"`
	Model               string         `json:"model,omitempty"`
	Store               bool           `json:"store,omitempty"`
	ReasoningEffort     string         `json:"reasoning_effort,omitempty"`
	MetaData            interface{}    `json:"metadata,omitempty"` // object or null
	FrequencyPenalty    float64        `json:"frequency_penalty,omitempty"`
	LogitBias           map[string]int `json:"logit_bias,omitempty"` // token ID -> bias value in [-100, 100]
	LogProbs            bool           `json:"logprobs,omitempty"`   // whether to return log probabilities of output tokens
	TopLogProbs         int            `json:"top_logprobs,omitempty"`
	MaxTokens           int            `json:"max_tokens,omitempty"` // deprecated in favor of max_completion_tokens
	MaxCompletionTokens int            `json:"max_completion_tokens,omitempty"`
	N                   int            `json:"n,omitempty"`
	Modalities          []string       `json:"modalities,omitempty"`
	Prediction          interface{}    `json:"prediction,omitempty"`
	PresencePenalty     float64        `json:"presence_penalty,omitempty"`

	// Pointer so omitempty actually drops the field when unset
	// (omitempty has no effect on non-pointer struct values).
	Audio *struct {
		Voice  string `json:"voice,omitempty"`
		Format string `json:"format,omitempty"`
	} `json:"audio,omitempty"`

	ResponseFormat interface{} `json:"response_format,omitempty"`
	Seed           int         `json:"seed,omitempty"`
	ServiceTier    string      `json:"service_tier,omitempty"`
	Stop           interface{} `json:"stop,omitempty"`
	Stream         bool        `json:"stream,omitempty"`

	// Pointer for the same omitempty reason as Audio above.
	StreamOptions *struct {
		IncludeUsage bool `json:"include_usage,omitempty"`
	} `json:"stream_options,omitempty"`

	Temperature float64 `json:"temperature,omitempty"`
	TopP        float64 `json:"top_p,omitempty"`

	Tools []struct {
		Type     string `json:"type,omitempty"`
		Function struct {
			Name        string      `json:"name,omitempty"`
			Description string      `json:"description,omitempty"`
			Parameters  interface{} `json:"parameters,omitempty"`
			Strict      bool        `json:"strict,omitempty"`
		} `json:"function,omitempty"`
	} `json:"tools,omitempty"`

	ToolChoice        interface{} `json:"tool_choice,omitempty"`
	ParallelToolCalls bool        `json:"parallel_tool_calls,omitempty"`
	Suffix            string      `json:"suffix,omitempty"`
	User              string      `json:"user,omitempty"`
}
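
// Example request literal (a sketch with illustrative values) and, given the
// omitempty tags above, roughly the JSON it marshals to:
//
//	req := &CreateCompletionsRequest{
//		Model: "gpt-4o",
//		Messages: []Message{
//			{Role: "system", Content: "You are a helpful assistant."},
//			{Role: "user", Content: "Hello"},
//		},
//		Temperature: 0.7,
//	}
//	// {"messages":[{"role":"system","content":"You are a helpful assistant."},
//	//  {"role":"user","content":"Hello"}],"model":"gpt-4o","temperature":0.7}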

type Message struct {
	Role    string `json:"role,omitempty"`
	Content string `json:"content,omitempty"`
	Name    string `json:"name,omitempty"`
}

type CreateCompletionsResponse struct {
	ID                string `json:"id,omitempty"`
	Object            string `json:"object,omitempty"`
	Created           int    `json:"created,omitempty"`
	Model             string `json:"model,omitempty"`
	ServiceTier       string `json:"service_tier,omitempty"`
	SystemFingerprint string `json:"system_fingerprint,omitempty"`

	Choices []struct {
		Index int `json:"index,omitempty"`

		Message struct {
			Role      string      `json:"role,omitempty"`
			Content   string      `json:"content,omitempty"`
			Refusal   string      `json:"refusal,omitempty"`
			ToolCalls interface{} `json:"tool_calls,omitempty"`
		} `json:"message"`

		Logprobs     interface{} `json:"logprobs,omitempty"`
		FinishReason string      `json:"finish_reason,omitempty"`
	} `json:"choices,omitempty"`

	Usage Usage `json:"usage,omitempty"`

	Error *Error `json:"error,omitempty"`
}

type Usage struct {
	PromptTokens            int         `json:"prompt_tokens,omitempty"`
	CompletionTokens        int         `json:"completion_tokens,omitempty"`
	TotalTokens             int         `json:"total_tokens,omitempty"`
	CompletionTokensDetails interface{} `json:"completion_tokens_details,omitempty"`
	PromptTokensDetails     interface{} `json:"prompt_tokens_details,omitempty"`
}

type Error struct {
	Message string      `json:"message,omitempty"`
	Type    string      `json:"type,omitempty"`
	Param   interface{} `json:"param,omitempty"`
	Code    interface{} `json:"code,omitempty"`
}

var (
	ErrMissingBoth     = errors.New("both messages and model fields not provided")
	ErrMissingMessages = errors.New("messages field not provided")
	ErrMissingModel    = errors.New("model field not provided")
)

Contributor review comment: @yash-sojitra Why have we exported these errors? Do we intend that users should use them? If not, can we make them unexported?
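
If these errors remain exported, callers can branch on them without comparing error strings; a minimal caller-side sketch, assuming this package is imported as openai and client is an initialized *Client:

	resp, err := client.CreateCompletions(ctx, req)
	if errors.Is(err, openai.ErrMissingModel) {
		// react to the specific validation failure
	}
	_ = resp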

func (e *Error) Error() string {
	// Code is an interface{} (string or number in the API response), so use %v.
	return fmt.Sprintf("%v: %s", e.Code, e.Message)
}

func (c *Client) CreateCompletionsRaw(ctx context.Context, r *CreateCompletionsRequest) ([]byte, error) {
	return c.Post(ctx, CompletionsEndpoint, r)
}
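
// CreateCompletionsRaw returns the undecoded response body, which lets a
// caller pull fields the typed response does not model. A sketch, assuming
// req is a valid *CreateCompletionsRequest:
//
//	raw, err := client.CreateCompletionsRaw(ctx, req)
//	if err == nil {
//		var v map[string]interface{}
//		_ = json.Unmarshal(raw, &v)
//	}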

func (c *Client) CreateCompletions(ctx context.Context, r *CreateCompletionsRequest) (response *CreateCompletionsResponse, err error) {
	tracerCtx, span := c.AddTrace(ctx, "CreateCompletions")
	startTime := time.Now()

	// End the span on every return path, including early validation errors.
	if span != nil {
		defer span.End()
	}

	if r.Messages == nil && r.Model == "" {
		c.logger.Errorf("%v", ErrMissingBoth)
		return nil, ErrMissingBoth
	}

	if r.Messages == nil {
		c.logger.Errorf("%v", ErrMissingMessages)
		return nil, ErrMissingMessages
	}

	if r.Model == "" {
		c.logger.Errorf("%v", ErrMissingModel)
		return nil, ErrMissingModel
	}

	raw, err := c.CreateCompletionsRaw(tracerCtx, r)
	if err != nil {
		return nil, err
	}

	err = json.Unmarshal(raw, &response)
	if err != nil {
		return nil, err
	}

	ql := &APILog{
		ID:                response.ID,
		Object:            response.Object,
		Created:           response.Created,
		Model:             response.Model,
		ServiceTier:       response.ServiceTier,
		SystemFingerprint: response.SystemFingerprint,
		Usage:             response.Usage,
		Error:             response.Error,
	}

	c.SendChatCompletionOperationStats(ctx, ql, startTime, "ChatCompletion", span)

	return response, nil
}
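
// End-to-end usage sketch; the client literal below mirrors the test setup in
// chatcompletion_test.go and may differ from the package's real constructor:
//
//	client := &Client{
//		config:     &Config{APIKey: "sk-...", BaseURL: "https://api.openai.com"},
//		httpClient: http.DefaultClient,
//		logger:     logger,  // any Logger implementation
//		metrics:    metrics, // any Metrics implementation
//	}
//	resp, err := client.CreateCompletions(context.Background(), &CreateCompletionsRequest{
//		Model:    "gpt-3.5-turbo",
//		Messages: []Message{{Role: "user", Content: "Hello"}},
//	})
//	if err == nil && len(resp.Choices) > 0 {
//		fmt.Println(resp.Choices[0].Message.Content)
//	}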

func (c *Client) SendChatCompletionOperationStats(ctx context.Context, ql *APILog, startTime time.Time, method string, span trace.Span) {
	// Duration is recorded in microseconds.
	duration := time.Since(startTime).Microseconds()

	ql.Duration = duration

	c.logger.Debug(ql)

	c.metrics.RecordHistogram(ctx, "openai_api_request_duration", float64(duration))
	c.metrics.RecordRequestCount(ctx, "openai_api_total_request_count")
	c.metrics.RecordTokenUsage(ctx, "openai_api_token_usage", ql.Usage.PromptTokens, ql.Usage.CompletionTokens)

	// The caller owns the span's lifetime; only annotate it here.
	if span != nil {
		span.SetAttributes(attribute.Int64(fmt.Sprintf("openai.%v.duration", method), duration))
	}
}
141 changes: 141 additions & 0 deletions pkg/gofr/datasource/openai/chatcompletion_test.go
@@ -0,0 +1,141 @@
package openai

import (
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/mock/gomock"
)

//nolint:funlen // Function length is intentional due to complexity
Contributor review comment: @yash-sojitra I think we can remove this nolint and instead break the test method into separate tests. That will make it easier to read.

func Test_ChatCompletions(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	mockLogger := NewMockLogger(ctrl)
	mockMetrics := NewMockMetrics(ctrl)

	tests := []struct {
		name          string
		request       *CreateCompletionsRequest
		response      *CreateCompletionsResponse
		expectedError error
		setupMocks    func(*MockLogger, *MockMetrics)
	}{
		{
			name: "successful completion request",
			request: &CreateCompletionsRequest{
				Messages: []Message{{Role: "user", Content: "Hello"}},
				Model:    "gpt-3.5-turbo",
			},
			response: &CreateCompletionsResponse{
				ID:      "test-id",
				Object:  "chat.completion",
				Created: 1234567890,
				Usage: Usage{
					PromptTokens:     10,
					CompletionTokens: 20,
					TotalTokens:      30,
				},
			},
			expectedError: nil,
			setupMocks: func(logger *MockLogger, metrics *MockMetrics) {
				metrics.EXPECT().RecordHistogram(gomock.Any(), "openai_api_request_duration", gomock.Any())
				metrics.EXPECT().RecordRequestCount(gomock.Any(), "openai_api_total_request_count")
				metrics.EXPECT().RecordTokenUsage(gomock.Any(), "openai_api_token_usage", 10, 20)
				logger.EXPECT().Debug(gomock.Any())
			},
		},
		{
			name:          "missing both messages and model",
			request:       &CreateCompletionsRequest{},
			expectedError: ErrMissingBoth,
			setupMocks: func(logger *MockLogger, _ *MockMetrics) {
				logger.EXPECT().Errorf("%v", ErrMissingBoth)
			},
		},
		{
			name: "missing messages",
			request: &CreateCompletionsRequest{
				Model: "gpt-3.5-turbo",
			},
			expectedError: ErrMissingMessages,
			setupMocks: func(logger *MockLogger, _ *MockMetrics) {
				logger.EXPECT().Errorf("%v", ErrMissingMessages)
			},
		},
		{
			name: "missing model",
			request: &CreateCompletionsRequest{
				Messages: []Message{{Role: "user", Content: "Hello"}},
			},
			expectedError: ErrMissingModel,
			setupMocks: func(logger *MockLogger, _ *MockMetrics) {
				logger.EXPECT().Errorf("%v", ErrMissingModel)
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var serverURL string

			var server *httptest.Server

			if tt.response != nil {
				server = setupTestServer(t, CompletionsEndpoint, tt.response)
				defer server.Close()
				serverURL = server.URL
			}

			client := &Client{
				config: &Config{
					APIKey:  "test-api-key",
					BaseURL: serverURL,
				},
				httpClient: http.DefaultClient,
				logger:     mockLogger,
				metrics:    mockMetrics,
			}

			tt.setupMocks(mockLogger, mockMetrics)

			response, err := client.CreateCompletions(context.Background(), tt.request)

			if tt.expectedError != nil {
				require.ErrorIs(t, err, tt.expectedError)
				assert.Nil(t, response)
			} else {
				require.NoError(t, err)
				assert.NotNil(t, response)
			}
		})
	}
}

func setupTestServer(t *testing.T, path string, response interface{}) *httptest.Server {
	t.Helper()

	server := httptest.NewServer(
		http.HandlerFunc(
			func(w http.ResponseWriter, r *http.Request) {
				assert.Equal(t, path, r.URL.Path)
				assert.Equal(t, "Bearer test-api-key", r.Header.Get("Authorization"))
				assert.Equal(t, "application/json", r.Header.Get("Content-Type"))

				w.Header().Set("Content-Type", "application/json")
				err := json.NewEncoder(w).Encode(response)

				if err != nil {
					t.Error(err)
					return
				}
			}))

	return server
}