Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open ai wrapper #1371

Open
wants to merge 10 commits into
base: development
Choose a base branch
from
184 changes: 184 additions & 0 deletions pkg/gofr/datasource/openai/chatcompletion.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
package openai

import (
"context"
"encoding/json"
"errors"
"fmt"
"time"

"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)

const CompletionsEndpoint = "/v1/chat/completions"

type CreateCompletionsRequest struct {
Messages []Message `json:"messages,omitempty"`
Model string `json:"model,omitempty"`
Store bool `json:"store,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
MetaData interface{} `json:"metadata,omitempty"` // object or null
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
LogitBias map[string]string `json:"logit_bias,omitempty"`
LogProbs int `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"` // deprecated
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
N int `json:"n,omitempty"`
Modalities []string `json:"modalities,omitempty"`
Prediction interface{} `json:"prediction,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`

Audio struct {
Voice string `json:"voice,omitempty"`
Format string `json:"format,omitempty"`
} `json:"audio,omitempty"`

ResposneFormat interface{} `json:"response_format,omitempty"`
yash-sojitra marked this conversation as resolved.
Show resolved Hide resolved
Seed int `json:"seed,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
Stop interface{} `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty"`

StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
} `json:"stram_options,omitempty"`

Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`

Tools []struct {
Type string `json:"type,omitempty"`
Function struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Parameters interface{} `json:"parameters,omitempty"`
Strict bool `json:"strict,omitempty"`
} `json:"function,omitempty"`
} `json:"tools,omitempty"`

ToolChoice interface{} `json:"toolChoice,omitempty"`
yash-sojitra marked this conversation as resolved.
Show resolved Hide resolved
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
Suffix string `json:"suffix,omitempty"`
User string `json:"user,omitempty"`
}

type Message struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
Name string `json:"name,omitempty"`
}

type CreateCompletionsResponse struct {
ID string `json:"id,omitempty"`
Object string `json:"object,omitempty"`
Created int `json:"created,omitempty"`
Model string `json:"model,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`

Choices []struct {
Index int `json:"index,omitempty"`

Message struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
Refusal string `json:"refusal,omitempty"`
ToolCalls interface{} `json:"tool_calls,omitempty"`
} `json:"message"`

Logprobs interface{} `json:"logprobs,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
} `json:"choices,omitempty"`

Usage struct {
PromptTokens int `json:"prompt_tokens,omitempty"`
CompletionTokens int `json:"completion_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
CompletionTokelDetails interface{} `json:"completion_tokens_details,omitempty"`
PromptTokenDetails interface{} `json:"prompt_tokens_details,omitempty"`
} `json:"usage,omitempty"`

Error *Error `json:"error,omitempty"`
}

type Error struct {
Message string `json:"message,omitempty"`
Type string `json:"type,omitempty"`
Param interface{} `json:"param,omitempty"`
Code interface{} `json:"code,omitempty"`
yash-sojitra marked this conversation as resolved.
Show resolved Hide resolved
}

var (
ErrMissingBoth = errors.New("both messages and model fields not provided")
ErrMissingMessages = errors.New("messages fields not provided")
ErrMissingModel = errors.New("model fields not provided")
)

func (e *Error) Error() string {
return fmt.Sprintf("%s: %s", e.Code, e.Message)
}

func (c *Client) CreateCompletionsRaw(ctx context.Context, r *CreateCompletionsRequest) ([]byte, error) {
return c.Post(ctx, CompletionsEndpoint, r)
}

func (c *Client) CreateCompletions(ctx context.Context, r *CreateCompletionsRequest) (response *CreateCompletionsResponse, err error) {
tracerCtx, span := c.AddTrace(ctx, "CreateCompletions")
startTime := time.Now()

if r.Messages == nil && r.Model == "" {
c.logger.Errorf("%v", ErrMissingBoth)
return nil, ErrMissingBoth
}

if r.Messages == nil {
c.logger.Errorf("%v", ErrMissingMessages)
return nil, ErrMissingMessages
}

if r.Model == "" {
c.logger.Errorf("%v", ErrMissingModel)
return nil, ErrMissingModel
}

raw, err := c.CreateCompletionsRaw(tracerCtx, r)
if err != nil {
return response, err
}

err = json.Unmarshal(raw, &response)

ql := &OpenAiAPILog{
ID: response.ID,
Object: response.Object,
Created: response.Created,
Model: response.Model,
ServiceTier: response.ServiceTier,
SystemFingerprint: response.SystemFingerprint,
Usage: response.Usage,
Error: response.Error,
}

c.SendChatCompletionOperationStats(ql, startTime, "ChatCompletion", span)

return response, err
}

func (c *Client) SendChatCompletionOperationStats(ql *OpenAiAPILog, startTime time.Time, method string, span trace.Span) {
duration := time.Since(startTime).Microseconds()

ql.Duration = duration

c.logger.Debug(ql)

c.metrics.RecordHistogram(context.Background(), "openai_api_request_duration", float64(duration))
c.metrics.RecordRequestCount(context.Background(), "openai_api_total_request_count")
c.metrics.RecordTokenUsage(context.Background(), "openai_api_token_usage", ql.Usage.PromptTokens, ql.Usage.CompletionTokens)

if span != nil {
defer span.End()
span.SetAttributes(attribute.Int64(fmt.Sprintf("openai.%v.duration", method), duration))
}
}
147 changes: 147 additions & 0 deletions pkg/gofr/datasource/openai/chatcompletion_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package openai

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

//nolint:funlen // Function length is intentional due to complexity
func Test_ChatCompletions(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockLogger := NewMockLogger(ctrl)
mockMetrics := NewMockMetrics(ctrl)

tests := []struct {
name string
request *CreateCompletionsRequest
response *CreateCompletionsResponse
expectedError error
setupMocks func(*MockLogger, *MockMetrics)
}{
{
name: "successful completion request",
request: &CreateCompletionsRequest{
Messages: []Message{{Role: "user", Content: "Hello"}},
Model: "gpt-3.5-turbo",
},
response: &CreateCompletionsResponse{
ID: "test-id",
Object: "chat.completion",
Created: 1234567890,
Usage: struct {
PromptTokens int `json:"prompt_tokens,omitempty"`
CompletionTokens int `json:"completion_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
CompletionTokelDetails interface{} `json:"completion_tokens_details,omitempty"`
PromptTokenDetails interface{} `json:"prompt_tokens_details,omitempty"`
}{
PromptTokens: 10,
CompletionTokens: 20,
TotalTokens: 30,
},
},
expectedError: nil,
setupMocks: func(logger *MockLogger, metrics *MockMetrics) {
metrics.EXPECT().RecordHistogram(gomock.Any(), "openai_api_request_duration", gomock.Any())
metrics.EXPECT().RecordRequestCount(gomock.Any(), "openai_api_total_request_count")
metrics.EXPECT().RecordTokenUsage(gomock.Any(), "openai_api_token_usage", 10, 20)
logger.EXPECT().Debug(gomock.Any())
},
},
{
name: "missing both messages and model",
request: &CreateCompletionsRequest{},
expectedError: ErrMissingBoth,
setupMocks: func(logger *MockLogger, _ *MockMetrics) {
logger.EXPECT().Errorf("%v", ErrMissingBoth)
},
},
{
name: "missing messages",
request: &CreateCompletionsRequest{
Model: "gpt-3.5-turbo",
},
expectedError: ErrMissingMessages,
setupMocks: func(logger *MockLogger, _ *MockMetrics) {
logger.EXPECT().Errorf("%v", ErrMissingMessages)
},
},
{
name: "missing model",
request: &CreateCompletionsRequest{
Messages: []Message{{Role: "user", Content: "Hello"}},
},
expectedError: ErrMissingModel,
setupMocks: func(logger *MockLogger, _ *MockMetrics) {
logger.EXPECT().Errorf("%v", ErrMissingModel)
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var serverURL string

var server *httptest.Server

if tt.response != nil {
server = setupTestServer(t, CompletionsEndpoint, tt.response)
defer server.Close()
serverURL = server.URL
}

client := &Client{
config: &Config{
APIKey: "test-api-key",
BaseURL: serverURL,
},
httpClient: http.DefaultClient,
logger: mockLogger,
metrics: mockMetrics,
}

tt.setupMocks(mockLogger, mockMetrics)

response, err := client.CreateCompletions(context.Background(), tt.request)

if tt.expectedError != nil {
assert.Equal(t, tt.expectedError, err)
yash-sojitra marked this conversation as resolved.
Show resolved Hide resolved
assert.Nil(t, response)
} else {
require.NoError(t, err)
assert.NotNil(t, response)
}
})
}
}

func setupTestServer(t *testing.T, path string, response interface{}) *httptest.Server {
t.Helper()

server := httptest.NewServer(
http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, path, r.URL.Path)
assert.Equal(t, "Bearer test-api-key", r.Header.Get("Authorization"))
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))

w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(response)

if err != nil {
t.Error(err)
return
}
yash-sojitra marked this conversation as resolved.
Show resolved Hide resolved
}))

return server
}
Loading
Loading