diff --git a/packages/sdk/server-ai/__tests__/LDAIClientImpl.test.ts b/packages/sdk/server-ai/__tests__/LDAIClientImpl.test.ts
index a396c2c1a..d586b7c0c 100644
--- a/packages/sdk/server-ai/__tests__/LDAIClientImpl.test.ts
+++ b/packages/sdk/server-ai/__tests__/LDAIClientImpl.test.ts
@@ -1,6 +1,6 @@
 import { LDContext } from '@launchdarkly/js-server-sdk-common';
 
-import { LDGenerationConfig } from '../src/api/config';
+import { LDAIDefaults } from '../src/api/config';
 import { LDAIClientImpl } from '../src/LDAIClientImpl';
 import { LDClientMin } from '../src/LDClientMin';
 
@@ -32,13 +32,14 @@ it('handles empty variables in template interpolation', () => {
 it('returns model config with interpolated prompts', async () => {
   const client = new LDAIClientImpl(mockLdClient);
   const key = 'test-flag';
-  const defaultValue: LDGenerationConfig = {
+  const defaultValue: LDAIDefaults = {
     model: { modelId: 'test', name: 'test-model' },
     prompt: [],
+    enabled: true,
   };
 
   const mockVariation = {
-    model: { modelId: 'example-provider', name: 'imagination' },
+    model: { modelId: 'example-provider', name: 'imagination', temperature: 0.7, maxTokens: 4096 },
     prompt: [
       { role: 'system', content: 'Hello {{name}}' },
       { role: 'user', content: 'Score: {{score}}' },
@@ -55,13 +56,11 @@ it('returns model config with interpolated prompts', async () => {
   const result = await client.modelConfig(key, testContext, defaultValue, variables);
 
   expect(result).toEqual({
-    config: {
-      model: { modelId: 'example-provider', name: 'imagination' },
-      prompt: [
-        { role: 'system', content: 'Hello John' },
-        { role: 'user', content: 'Score: 42' },
-      ],
-    },
+    model: { modelId: 'example-provider', name: 'imagination', temperature: 0.7, maxTokens: 4096 },
+    prompt: [
+      { role: 'system', content: 'Hello John' },
+      { role: 'user', content: 'Score: 42' },
+    ],
     tracker: expect.any(Object),
     enabled: true,
   });
@@ -70,7 +69,7 @@ it('returns model config with interpolated prompts', async () => {
 it('includes context in variables for prompt interpolation', async () => {
   const client = new LDAIClientImpl(mockLdClient);
   const key = 'test-flag';
-  const defaultValue: LDGenerationConfig = {
+  const defaultValue: LDAIDefaults = {
     model: { modelId: 'test', name: 'test-model' },
     prompt: [],
   };
@@ -84,13 +83,13 @@ it('includes context in variables for prompt interpolation', async () => {
   const result = await client.modelConfig(key, testContext, defaultValue);
 
-  expect(result.config.prompt?.[0].content).toBe('User key: test-user');
+  expect(result.prompt?.[0].content).toBe('User key: test-user');
 });
 
 it('handles missing metadata in variation', async () => {
   const client = new LDAIClientImpl(mockLdClient);
   const key = 'test-flag';
-  const defaultValue: LDGenerationConfig = {
+  const defaultValue: LDAIDefaults = {
     model: { modelId: 'test', name: 'test-model' },
     prompt: [],
   };
 
@@ -105,10 +104,8 @@ it('handles missing metadata in variation', async () => {
   const result = await client.modelConfig(key, testContext, defaultValue);
 
   expect(result).toEqual({
-    config: {
-      model: { modelId: 'example-provider', name: 'imagination' },
-      prompt: [{ role: 'system', content: 'Hello' }],
-    },
+    model: { modelId: 'example-provider', name: 'imagination' },
+    prompt: [{ role: 'system', content: 'Hello' }],
     tracker: expect.any(Object),
     enabled: false,
   });
@@ -117,9 +114,10 @@ it('handles missing metadata in variation', async () => {
 it('passes the default value to the underlying client', async () => {
   const client = new LDAIClientImpl(mockLdClient);
   const key = 'non-existent-flag';
-  const defaultValue: LDGenerationConfig = {
+  const defaultValue: LDAIDefaults = {
     model: { modelId: 'default-model', name: 'default' },
     prompt: [{ role: 'system', content: 'Default prompt' }],
+    enabled: true,
   };
 
   mockLdClient.variation.mockResolvedValue(defaultValue);
@@ -127,7 +125,8 @@ it('passes the default value to the underlying client', async () => {
   const result = await client.modelConfig(key, testContext, defaultValue);
 
   expect(result).toEqual({
-    config: defaultValue,
+    model: defaultValue.model,
+    prompt: defaultValue.prompt,
     tracker: expect.any(Object),
     enabled: false,
   });
diff --git a/packages/sdk/server-ai/examples/bedrock/src/index.ts b/packages/sdk/server-ai/examples/bedrock/src/index.ts
index 95bd83da6..39fd3fb03 100644
--- a/packages/sdk/server-ai/examples/bedrock/src/index.ts
+++ b/packages/sdk/server-ai/examples/bedrock/src/index.ts
@@ -66,8 +66,12 @@ async function main() {
   const completion = tracker.trackBedrockConverse(
     await awsClient.send(
       new ConverseCommand({
-        modelId: aiConfig.config.model?.modelId ?? 'no-model',
-        messages: mapPromptToConversation(aiConfig.config.prompt ?? []),
+        modelId: aiConfig.model?.modelId ?? 'no-model',
+        messages: mapPromptToConversation(aiConfig.prompt ?? []),
+        inferenceConfig: {
+          temperature: aiConfig.model?.temperature ?? 0.5,
+          maxTokens: aiConfig.model?.maxTokens ?? 4096,
+        },
       }),
     ),
   );
diff --git a/packages/sdk/server-ai/examples/openai/src/index.ts b/packages/sdk/server-ai/examples/openai/src/index.ts
index 041d1f076..abdae7612 100644
--- a/packages/sdk/server-ai/examples/openai/src/index.ts
+++ b/packages/sdk/server-ai/examples/openai/src/index.ts
@@ -60,8 +60,10 @@ async function main(): Promise<void> {
   const { tracker } = aiConfig;
   const completion = await tracker.trackOpenAI(async () =>
     client.chat.completions.create({
-      messages: aiConfig.config.prompt || [],
-      model: aiConfig.config.model?.modelId || 'gpt-4',
+      messages: aiConfig.prompt || [],
+      model: aiConfig.model?.modelId || 'gpt-4',
+      temperature: aiConfig.model?.temperature ?? 0.5,
+      max_tokens: aiConfig.model?.maxTokens ?? 4096,
     }),
   );
 
diff --git a/packages/sdk/server-ai/src/LDAIClientImpl.ts b/packages/sdk/server-ai/src/LDAIClientImpl.ts
index dc87d270e..7cf7b6f2d 100644
--- a/packages/sdk/server-ai/src/LDAIClientImpl.ts
+++ b/packages/sdk/server-ai/src/LDAIClientImpl.ts
@@ -2,7 +2,7 @@ import * as Mustache from 'mustache';
 
 import { LDContext } from '@launchdarkly/js-server-sdk-common';
 
-import { LDAIConfig, LDGenerationConfig, LDMessage, LDModelConfig } from './api/config';
+import { LDAIConfig, LDAIDefaults, LDMessage, LDModelConfig } from './api/config';
 import { LDAIClient } from './api/LDAIClient';
 import { LDAIConfigTrackerImpl } from './LDAIConfigTrackerImpl';
 import { LDClientMin } from './LDClientMin';
@@ -32,16 +32,28 @@ export class LDAIClientImpl implements LDAIClient {
     return Mustache.render(template, variables, undefined, { escape: (item: any) => item });
   }
 
-  async modelConfig<TDefault extends LDGenerationConfig>(
+  async modelConfig(
     key: string,
     context: LDContext,
-    defaultValue: TDefault,
+    defaultValue: LDAIDefaults,
     variables?: Record<string, unknown>,
   ): Promise<LDAIConfig> {
     const value: VariationContent = await this._ldClient.variation(key, context, defaultValue);
+    const tracker = new LDAIConfigTrackerImpl(
+      this._ldClient,
+      key,
+      // eslint-disable-next-line no-underscore-dangle
+      value._ldMeta?.versionKey ?? '',
+      context,
+    );
+    // eslint-disable-next-line no-underscore-dangle
+    const enabled = !!value._ldMeta?.enabled;
+    const config: LDAIConfig = {
+      tracker,
+      enabled,
+    };
     // We are going to modify the contents before returning them, so we make a copy.
     // This isn't a deep copy and the application developer should not modify the returned content.
-    const config: LDGenerationConfig = {};
     if (value.model) {
       config.model = { ...value.model };
     }
@@ -54,18 +66,6 @@
       }));
     }
 
-    return {
-      config,
-      // eslint-disable-next-line no-underscore-dangle
-      tracker: new LDAIConfigTrackerImpl(
-        this._ldClient,
-        key,
-        // eslint-disable-next-line no-underscore-dangle
-        value._ldMeta?.versionKey ?? '',
-        context,
-      ),
-      // eslint-disable-next-line no-underscore-dangle
-      enabled: !!value._ldMeta?.enabled,
-    };
+    return config;
   }
 }
diff --git a/packages/sdk/server-ai/src/api/LDAIClient.ts b/packages/sdk/server-ai/src/api/LDAIClient.ts
index cffd657a7..d716bc230 100644
--- a/packages/sdk/server-ai/src/api/LDAIClient.ts
+++ b/packages/sdk/server-ai/src/api/LDAIClient.ts
@@ -1,16 +1,6 @@
 import { LDContext } from '@launchdarkly/js-server-sdk-common';
 
-import { LDAIConfig, LDGenerationConfig } from './config/LDAIConfig';
-
-/**
- * Interface for default model configuration.
- */
-export interface LDAIDefaults extends LDGenerationConfig {
-  /**
-   * Whether the configuration is enabled.
-   */
-  enabled?: boolean;
-}
+import { LDAIConfig, LDAIDefaults } from './config/LDAIConfig';
 
 /**
  * Interface for performing AI operations using LaunchDarkly.
@@ -77,10 +67,10 @@
    * }
    * ```
    */
-  modelConfig<TDefault extends LDAIDefaults>(
+  modelConfig(
     key: string,
     context: LDContext,
-    defaultValue: TDefault,
+    defaultValue: LDAIDefaults,
     variables?: Record<string, unknown>,
   ): Promise<LDAIConfig>;
 }
diff --git a/packages/sdk/server-ai/src/api/config/LDAIConfig.ts b/packages/sdk/server-ai/src/api/config/LDAIConfig.ts
index 432a0a732..2dfa4baee 100644
--- a/packages/sdk/server-ai/src/api/config/LDAIConfig.ts
+++ b/packages/sdk/server-ai/src/api/config/LDAIConfig.ts
@@ -7,7 +7,7 @@ export interface LDModelConfig {
   /**
    * The ID of the model.
   */
-  modelId?: string;
+  modelId: string;
 
   /**
    * Tuning parameter for randomness versus determinism. Exact effect will be determined by the
@@ -41,35 +41,38 @@ export interface LDMessage {
 }
 
 /**
- * Configuration which affects generation.
+ * AI configuration and tracker.
  */
-export interface LDGenerationConfig {
+export interface LDAIConfig {
   /**
    * Optional model configuration.
    */
   model?: LDModelConfig;
   /**
    * Optional prompt data.
    */
   prompt?: LDMessage[];
-}
-
-/**
- * AI Config value and tracker.
- */
-export interface LDAIConfig {
-  /**
-   * The result of the AI Config customization.
-   */
-  config: LDGenerationConfig;
 
   /**
    * A tracker which can be used to generate analytics.
    */
   tracker: LDAIConfigTracker;
 
   /**
    * Whether the configuration is enabled.
    */
   enabled: boolean;
 }
+
+/**
+ * Default value for a `modelConfig`. This is the same as the LDAIConfig, but it does not include
+ * a tracker and `enabled` is optional.
+ */
+export type LDAIDefaults = Omit<LDAIConfig, 'tracker' | 'enabled'> & {
+  /**
+   * Whether the configuration is enabled.
+   *
+   * defaults to false
+   */
+  enabled?: boolean;
+};
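
Taken together, these changes flatten `LDAIConfig` so the model and prompt customization sit directly beside `tracker` and `enabled` instead of under a nested `config` key. A minimal consumer sketch of the new shape follows; it assumes `LDAIClient` and `LDAIDefaults` are re-exported from the package entry point (shown here as `@launchdarkly/server-sdk-ai`), and the flag key and variable values are illustrative:

```ts
import { LDContext } from '@launchdarkly/js-server-sdk-common';

import { LDAIClient, LDAIDefaults } from '@launchdarkly/server-sdk-ai';

// Fallback served when no variation is available; `enabled` is optional
// on LDAIDefaults and defaults to false.
const defaultValue: LDAIDefaults = {
  model: { modelId: 'gpt-4' },
  prompt: [{ role: 'system', content: 'Hello {{name}}' }],
};

async function getGenerationSettings(aiClient: LDAIClient, context: LDContext) {
  // 'my-ai-config' is a hypothetical flag key; `name` fills the {{name}}
  // placeholder via Mustache interpolation.
  const aiConfig = await aiClient.modelConfig('my-ai-config', context, defaultValue, {
    name: 'Sandy',
  });

  // The customization now lives at the top level; there is no
  // `aiConfig.config` nesting.
  return {
    enabled: aiConfig.enabled,
    modelId: aiConfig.model?.modelId ?? 'no-model',
    temperature: aiConfig.model?.temperature ?? 0.5,
    maxTokens: aiConfig.model?.maxTokens ?? 4096,
    messages: aiConfig.prompt ?? [],
    // Wrap the provider call with this (e.g. trackOpenAI or
    // trackBedrockConverse) to record analytics for the served config.
    tracker: aiConfig.tracker,
  };
}
```

This mirrors the bedrock and openai example updates above: callers read `model` and `prompt` off the returned config directly and fall back to their own defaults for any tuning parameters the flag does not set.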