diff --git a/.idea/externalDependencies.xml b/.idea/externalDependencies.xml index efd90ca7..98e550ba 100644 --- a/.idea/externalDependencies.xml +++ b/.idea/externalDependencies.xml @@ -15,7 +15,6 @@ - diff --git a/.idea/runConfigurations/Cohere_Connector.xml b/.idea/runConfigurations/Cohere_Connector.xml new file mode 100644 index 00000000..09ea8f44 --- /dev/null +++ b/.idea/runConfigurations/Cohere_Connector.xml @@ -0,0 +1,15 @@ + + + + + + + + + + + diff --git a/.idea/runConfigurations/go_test_github_com_basemind_ai_monorepo_services_api_gateway_internal_services.xml b/.idea/runConfigurations/go_test_github_com_basemind_ai_monorepo_services_api_gateway_internal_services.xml index e43d08cd..871b6037 100644 --- a/.idea/runConfigurations/go_test_github_com_basemind_ai_monorepo_services_api_gateway_internal_services.xml +++ b/.idea/runConfigurations/go_test_github_com_basemind_ai_monorepo_services_api_gateway_internal_services.xml @@ -5,7 +5,6 @@ type="GoTestRunConfiguration" factoryName="Go Test" > - - { + store.resetState(); router.push( setRouteParams( Navigation.ApplicationDetail, diff --git a/frontend/src/components/projects/[projectId]/applications/[applicationId]/config-create-wizard/base-form.spec.tsx b/frontend/src/components/projects/[projectId]/applications/[applicationId]/config-create-wizard/base-form.spec.tsx index 88106953..ae0bb0de 100644 --- a/frontend/src/components/projects/[projectId]/applications/[applicationId]/config-create-wizard/base-form.spec.tsx +++ b/frontend/src/components/projects/[projectId]/applications/[applicationId]/config-create-wizard/base-form.spec.tsx @@ -175,8 +175,9 @@ describe('PromptConfigBaseForm', () => { expect(setConfigName).toHaveBeenCalledWith('New Config'); }); - it('should display the model vendor select field as disabled', () => { + it('should display and allow selection of model vendor', () => { const modelVendor = ModelVendor.OpenAI; + const setVendor = vi.fn(); render( { modelVendor={modelVendor} setConfigName={vi.fn()} setModelType={vi.fn()} - setVendor={vi.fn()} + setVendor={setVendor} />, ); const select: HTMLInputElement = screen.getByTestId( @@ -194,7 +195,8 @@ describe('PromptConfigBaseForm', () => { ); expect(select).toBeInTheDocument(); expect(select.value).toBe(modelVendor); - expect(select.disabled).toBe(true); + fireEvent.change(select, { target: { value: ModelVendor.Cohere } }); + expect(setVendor).toHaveBeenCalledWith(ModelVendor.Cohere); }); it('should display and allow editing of the model type select field based on the selected model vendor', () => { diff --git a/frontend/src/components/projects/[projectId]/applications/[applicationId]/config-create-wizard/base-form.tsx b/frontend/src/components/projects/[projectId]/applications/[applicationId]/config-create-wizard/base-form.tsx index b45fa56d..6c2150d8 100644 --- a/frontend/src/components/projects/[projectId]/applications/[applicationId]/config-create-wizard/base-form.tsx +++ b/frontend/src/components/projects/[projectId]/applications/[applicationId]/config-create-wizard/base-form.tsx @@ -1,5 +1,5 @@ import { useTranslations } from 'next-intl'; -import { useMemo } from 'react'; +import { useEffect, useMemo } from 'react'; import { EntityNameInput } from '@/components/entity-name-input'; import { @@ -7,7 +7,12 @@ import { modelVendorTypeMap, UnavailableModelVendor, } from '@/constants/models'; -import { ModelType, ModelVendor } from '@/types'; +import { + CohereModelType, + ModelType, + ModelVendor, + OpenAIModelType, +} from '@/types'; import { handleChange } from '@/utils/events'; export function PromptConfigBaseForm({ @@ -33,8 +38,49 @@ export function PromptConfigBaseForm({ }) { const t = useTranslations('createConfigWizard'); - const modelChoices = useMemo(() => { - return Object.values(modelVendorTypeMap[modelVendor]) as string[]; + const activeModelVendor = useMemo( + () => + Object.entries(ModelVendor) + .filter(([, value]) => + [ModelVendor.OpenAI, ModelVendor.Cohere].includes(value), + ) + .map(([key, value]) => { + return ( + + ); + }), + [], + ); + + const invalidModelVendors = useMemo( + () => + Object.entries(UnavailableModelVendor).map(([key, value]) => ( + + )), + [], + ); + + useEffect(() => { + if ( + modelVendor === ModelVendor.Cohere && + !Object.values(CohereModelType).includes( + modelType as CohereModelType, + ) + ) { + setModelType(CohereModelType.Command); + } + if ( + modelVendor === ModelVendor.OpenAI && + !Object.values(OpenAIModelType).includes( + modelType as OpenAIModelType, + ) + ) { + setModelType(OpenAIModelType.Gpt35Turbo); + } }, [modelVendor]); return ( @@ -73,27 +119,10 @@ export function PromptConfigBaseForm({ className="card-select w-full" value={modelVendor} onChange={handleChange(setVendor)} - disabled={true} data-testid="create-prompt-base-form-vendor-select" > - {Object.entries(ModelVendor) - .filter( - ([, value]) => value === ModelVendor.OpenAI, - ) - .map(([key, value]) => { - return ( - - ); - })} - {Object.entries(UnavailableModelVendor).map( - ([key, value]) => ( - - ), - )} + {activeModelVendor} + {invalidModelVendors}
@@ -108,20 +137,22 @@ export function PromptConfigBaseForm({ onChange={handleChange(setModelType)} data-testid="create-prompt-base-form-model-select" > - {modelChoices.map((modelChoice) => { - return ( - - ); - })} + {Object.values(modelVendorTypeMap[modelVendor]).map( + (modelChoice: string) => { + return ( + + ); + }, + )}
diff --git a/frontend/src/components/projects/[projectId]/applications/[applicationId]/config-create-wizard/cohere-form-components.spec.tsx b/frontend/src/components/projects/[projectId]/applications/[applicationId]/config-create-wizard/cohere-form-components.spec.tsx new file mode 100644 index 00000000..79e80f54 --- /dev/null +++ b/frontend/src/components/projects/[projectId]/applications/[applicationId]/config-create-wizard/cohere-form-components.spec.tsx @@ -0,0 +1,245 @@ +import { CoherePromptParametersFactory } from 'tests/factories'; +import { + fireEvent, + getLocaleNamespace, + render, + screen, +} from 'tests/test-utils'; + +import { + CohereModelParametersForm, + CoherePromptTemplate, +} from '@/components/projects/[projectId]/applications/[applicationId]/config-create-wizard/cohere-form-components'; +import { CohereModelParameters, CohereModelType } from '@/types'; + +describe('Cohere Form Components', () => { + const locales = getLocaleNamespace('createConfigWizard'); + + describe('CohereModelParametersForm Tests', () => { + it('should render the component and all sliders using the provided existing parameters as initial values', () => { + const exitingParameters = + CoherePromptParametersFactory.buildSync() as Required; + + render( + , + ); + + const maxTokensSliderInput: HTMLInputElement = screen.getByTestId( + 'parameter-slider-maxTokens', + ); + expect(maxTokensSliderInput.value).toBe( + exitingParameters.maxTokens.toString(), + ); + + const temperatureSliderInput: HTMLInputElement = screen.getByTestId( + 'parameter-slider-temperature', + ); + expect(temperatureSliderInput.value).toBe( + exitingParameters.temperature.toString(), + ); + + const pSliderInput: HTMLInputElement = + screen.getByTestId('parameter-slider-p'); + expect(pSliderInput.value).toBe(exitingParameters.p.toString()); + + const kSliderInput: HTMLInputElement = + screen.getByTestId('parameter-slider-k'); + expect(kSliderInput.value).toBe(exitingParameters.k.toString()); + + const frequencyPenaltySliderInput: HTMLInputElement = + screen.getByTestId('parameter-slider-frequencyPenalty'); + expect(frequencyPenaltySliderInput.value).toBe( + exitingParameters.frequencyPenalty.toString(), + ); + + const presencePenaltySliderInput: HTMLInputElement = + screen.getByTestId('parameter-slider-presencePenalty'); + expect(presencePenaltySliderInput.value).toBe( + exitingParameters.presencePenalty.toString(), + ); + }); + + it('should update the value of the slider when it is adjusted by the user', () => { + const exitingParameters = + CoherePromptParametersFactory.buildSync() as Required; + + render( + , + ); + + const maxTokensSliderInput: HTMLInputElement = screen.getByTestId( + 'parameter-slider-maxTokens', + ); + + fireEvent.change(maxTokensSliderInput, { target: { value: '50' } }); + + expect(maxTokensSliderInput.value).toBe('50'); + }); + + it('should update the model parameters correctly when the slider is adjusted', () => { + const exitingParameters = + CoherePromptParametersFactory.buildSync() as Required; + + const setParametersMock = vi.fn(); + + render( + , + ); + + const maxTokensSliderInput: HTMLInputElement = screen.getByTestId( + 'parameter-slider-maxTokens', + ); + + fireEvent.change(maxTokensSliderInput, { target: { value: '50' } }); + + expect(setParametersMock).toHaveBeenCalledWith({ + ...exitingParameters, + maxTokens: 50, + }); + }); + }); + + describe('CoherePromptTemplate tests', () => { + it('should render the form with the correct header and label text', () => { + const setMessages = vi.fn(); + + render( + , + ); + + expect( + screen.getByTestId('cohere-prompt-template-form-header'), + ).toHaveTextContent(locales.promptTemplate); + expect( + screen.getByTestId('cohere-prompt-template-form-label-text'), + ).toHaveTextContent(locales.messageContent); + }); + + it('should display the message content textarea with the correct placeholder text', () => { + const setMessages = vi.fn(); + + render( + , + ); + + const textArea: HTMLTextAreaElement = screen.getByTestId( + 'cohere-prompt-template-form-textarea', + ); + + expect(textArea.placeholder).toBe( + locales.promptMessagePlaceholder.replaceAll("'", ''), + ); + }); + + it('should update the message state correctly when text is entered into the textarea', () => { + const setMessages = vi.fn(); + + render( + , + ); + + const textarea = screen.getByTestId( + 'cohere-prompt-template-form-textarea', + ); + fireEvent.change(textarea, { target: { value: 'Test message' } }); + + expect(setMessages).toHaveBeenCalledWith([ + { + message: 'Test message', + templateVariables: [], + }, + ]); + }); + + it('should extract template variables correctly from the entered message', () => { + const setMessages = vi.fn(); + + render( + , + ); + + const textarea = screen.getByTestId( + 'cohere-prompt-template-form-textarea', + ); + fireEvent.change(textarea, { target: { value: 'Hello {name}!' } }); + + expect(setMessages).toHaveBeenCalledWith([ + { + message: 'Hello {name}!', + templateVariables: ['name'], + }, + ]); + }); + + it('should render the form with empty message content when no messages are passed in', () => { + const setMessages = vi.fn(); + + render( + , + ); + + expect( + screen.getByTestId('cohere-prompt-template-form-textarea'), + ).toHaveValue(''); + }); + + it('should render the form with empty message content when messages array is empty', () => { + const setMessages = vi.fn(); + + render( + , + ); + + expect( + screen.getByTestId('cohere-prompt-template-form-textarea'), + ).toHaveValue(''); + }); + + it('should display a text info label with wrap variable instructions', () => { + const setMessages = vi.fn(); + + render( + , + ); + + expect( + screen.getByTestId( + 'cohere-prompt-template-form-label-alt-text', + ), + ).toHaveTextContent(locales.wrapVariable.replaceAll("'", '')); + }); + }); +}); diff --git a/frontend/src/components/projects/[projectId]/applications/[applicationId]/config-create-wizard/cohere-form-components.tsx b/frontend/src/components/projects/[projectId]/applications/[applicationId]/config-create-wizard/cohere-form-components.tsx new file mode 100644 index 00000000..cc901a2c --- /dev/null +++ b/frontend/src/components/projects/[projectId]/applications/[applicationId]/config-create-wizard/cohere-form-components.tsx @@ -0,0 +1,201 @@ +import { useTranslations } from 'next-intl'; +import { Dispatch, SetStateAction, useEffect, useState } from 'react'; + +import { ParameterSlider } from '@/components/prompt-display-components/parameter-slider'; +import { DEFAULT_MAX_TOKENS } from '@/constants/models'; +import { + CohereModelType, + CoherePromptMessage, + ModelParameters, + ModelVendor, + ProviderMessageType, +} from '@/types'; +import { handleChange } from '@/utils/events'; +import { extractTemplateVariables } from '@/utils/models'; + +const COHERE_MAX_TOKENS = 4096; + +export function CohereModelParametersForm({ + setParameters, + existingParameters, +}: { + existingParameters?: ModelParameters; + modelType: CohereModelType; + setParameters: (parameters: ModelParameters) => void; +}) { + const t = useTranslations('createConfigWizard'); + + const [maxTokens, setMaxTokens] = useState( + existingParameters?.maxTokens ?? DEFAULT_MAX_TOKENS, + ); + const [frequencyPenalty, setFrequencyPenalty] = useState( + existingParameters?.frequencyPenalty ?? 0, + ); + const [presencePenalty, setPresencePenalty] = useState( + existingParameters?.presencePenalty ?? 0, + ); + const [temperature, setTemperature] = useState( + existingParameters?.temperature ?? 0, + ); + const [p, setP] = useState(existingParameters?.p ?? 0); + const [k, setK] = useState(existingParameters?.k ?? 0); + + const inputs: { + key: string; + labelText: string; + max: number; + min: number; + setter: Dispatch>; + step: number; + toolTipText: string; + value: number; + }[] = [ + { + key: 'maxTokens', + labelText: t('cohereParametersMaxTokensLabel'), + max: COHERE_MAX_TOKENS, // cohere has a static max value for now + min: 1, + setter: setMaxTokens, + step: 1, + toolTipText: t('cohereParametersMaxTokensTooltip'), + value: maxTokens, + }, + { + key: 'temperature', + labelText: t('cohereParametersTemperatureLabel'), + max: 5, + min: 0, + setter: setTemperature, + step: 0.1, + toolTipText: t('cohereParametersTemperatureTooltip'), + value: temperature, + }, + { + key: 'p', + labelText: t('cohereParametersPLabel'), + max: 0.99, + min: 0, + setter: setP, + step: 0.01, + toolTipText: t('cohereParametersPTooltip'), + value: p, + }, + { + key: 'k', + labelText: t('cohereParametersKLabel'), + max: 500, + min: 0, + setter: setK, + step: 1, + toolTipText: t('cohereParametersKTooltip'), + value: k, + }, + { + key: 'frequencyPenalty', + labelText: t('cohereParametersFrequencyPenaltyLabel'), + max: 9999, + min: 0, + setter: setFrequencyPenalty, + step: 1, + toolTipText: t('cohereParametersFrequencyPenaltyTooltip'), + value: frequencyPenalty, + }, + { + key: 'presencePenalty', + labelText: t('cohereParametersPresencePenaltyLabel'), + max: 1, + min: 0, + setter: setPresencePenalty, + step: 0.1, + toolTipText: t('cohereParametersPresencePenaltyTooltip'), + value: presencePenalty, + }, + ]; + + useEffect(() => { + setParameters({ + frequencyPenalty, + k, + maxTokens, + p, + presencePenalty, + temperature, + }); + }, [maxTokens, frequencyPenalty, presencePenalty, temperature, p, k]); + + return ( +
+

+ {t('modelParameters')} +

+
+ {inputs.map(ParameterSlider)} +
+
+ ); +} + +export function CoherePromptTemplate({ + messages, + setMessages, +}: { + messages: ProviderMessageType[]; + setMessages: (messages: ProviderMessageType[]) => void; +}) { + const t = useTranslations('createConfigWizard'); + + const handleSetMessages = (message: string) => { + setMessages([ + { + message, + templateVariables: extractTemplateVariables(message), + } satisfies CoherePromptMessage, + ]); + }; + + return ( +
+

+ {t('promptTemplate')} +

+
+ +