feat: cohere frontend #281

Merged 15 commits on Jan 1, 2024
1 change: 0 additions & 1 deletion .idea/externalDependencies.xml

Some generated files are not rendered by default.

15 changes: 15 additions & 0 deletions .idea/runConfigurations/Cohere_Connector.xml

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions docker-compose.yaml
@@ -45,6 +45,7 @@ services:
GOOGLE_APPLICATION_CREDENTIALS: ./.secrets/serviceAccountKey.json
JWT_SECRET: jeronimo
OPENAI_CONNECTOR_ADDRESS: openai-connector:8000
COHERE_CONNECTOR_ADDRESS: cohere-connector:9000
REDIS_CONNECTION_STRING: redis://redis:6379
SERVER_HOST: localhost
SERVER_PORT: 4000
48 changes: 32 additions & 16 deletions e2e/factories/factories.go
@@ -40,13 +40,16 @@ func CreateOpenAIPromptMessages(
return ptr.To(json.RawMessage(serialization.SerializeJSON(msgs)))
}

func CreateModelParameters() *json.RawMessage {
modelParameters := serialization.SerializeJSON(map[string]float32{
"temperature": 1,
"top_p": 1,
"max_tokens": 1,
"presence_penalty": 1,
"frequency_penalty": 1,
func CreateOpenAIModelParameters() *json.RawMessage {
floatValue := float32(0)
intValue := int32(0)

modelParameters := serialization.SerializeJSON(datatypes.OpenAIModelParametersDTO{
Temperature: &floatValue,
TopP: &floatValue,
MaxTokens: &intValue,
FrequencyPenalty: &floatValue,
PresencePenalty: &floatValue,
})

return ptr.To(json.RawMessage(modelParameters))
@@ -107,7 +110,7 @@ func CreateOpenAIPromptConfig(
userMessage := "This is what the user asked for: {userInput}"
templateVariables := []string{"userInput"}

modelParameters := CreateModelParameters()
modelParameters := CreateOpenAIModelParameters()

promptMessages := CreateOpenAIPromptMessages(
systemMessage,
@@ -132,24 +135,36 @@ func CreateOpenAIPromptConfig(
return &promptConfig, nil
}

func CreateCohereModelParameters() *json.RawMessage {
floatValue := float32(0)
uintValue := uint32(0)
intValue := int32(0)

modelParameters := serialization.SerializeJSON(datatypes.CohereModelParametersDTO{
Temperature: &floatValue,
K: &uintValue,
P: &floatValue,
FrequencyPenalty: &floatValue,
PresencePenalty: &floatValue,
MaxTokens: &intValue,
})
return ptr.To(json.RawMessage(modelParameters))
}

func CreateCoherePromptConfig(
ctx context.Context,
applicationID pgtype.UUID,
) (*models.PromptConfig, error) {
userMessage := "This is what the user asked for: {userInput}"
templateVariables := []string{"userInput"}

modelParameters := ptr.To(
json.RawMessage(serialization.SerializeJSON(datatypes.CohereModelParametersDTO{
Temperature: ptr.To(float32(1)),
})),
)
modelParameters := CreateCohereModelParameters()

promptMessages := ptr.To(
json.RawMessage(serialization.SerializeJSON(datatypes.CoherePromptMessageDTO{
Content: userMessage,
json.RawMessage(serialization.SerializeJSON([]datatypes.CoherePromptMessageDTO{{
Message: userMessage,
TemplateVariables: &templateVariables,
})),
}})),
)

promptConfig, promptConfigCreateErr := db.GetQueries().
@@ -186,6 +201,7 @@ func CreatePromptRequestRecord(
ResponseTokens: 18,
RequestTokensCost: *requestTokenCost,
ResponseTokensCost: *responseTokenCost,
FinishReason: models.PromptFinishReasonDONE,
StartTime: pgtype.Timestamptz{Time: promptStartTime, Valid: true},
FinishTime: pgtype.Timestamptz{Time: promptFinishTime, Valid: true},
DurationMs: pgtype.Int4{Int32: 0, Valid: true},
10 changes: 5 additions & 5 deletions frontend/package.json
@@ -21,7 +21,7 @@
"firebaseui": "^6.1.0",
"next": "14.0.4",
"next-intl": "^3.4.0",
"persist-and-sync": "^1.1.1",
"persist-and-sync": "^1.2.0",
"react": "18.2.0",
"react-bootstrap-icons": "^1.10.3",
"react-dom": "18.2.0",
@@ -38,9 +38,9 @@
"@tailwindcss/forms": "^0.5.7",
"@tailwindcss/typography": "^0.5.10",
"@testing-library/dom": "^9.3.3",
"@testing-library/jest-dom": "^6.1.5",
"@testing-library/jest-dom": "^6.1.6",
"@testing-library/react": "^14.1.2",
"@types/react": "18.2.45",
"@types/react": "18.2.46",
"@types/react-dom": "18.2.18",
"@types/react-syntax-highlighter": "^15.5.11",
"@types/validator": "^13.11.7",
@@ -50,8 +50,8 @@
"jsdom": "^23.0.1",
"mock-socket": "^9.3.1",
"postcss": "8.4.32",
"sass": "^1.69.5",
"sass": "^1.69.6",
"tailwindcss": "3.4.0",
"vite-plugin-magical-svg": "^1.1.0"
"vite-plugin-magical-svg": "^1.1.1"
}
}
15 changes: 14 additions & 1 deletion frontend/public/messages/en.json
@@ -85,6 +85,19 @@
"backButtonText": "Back",
"cancelButtonText": "Cancel",
"cohere": "Cohere",
"cohereParametersFrequencyPenaltyLabel": "Frequency Penalty",
"cohereParametersFrequencyPenaltyTooltip": "Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.",
"cohereParametersKLabel": "K",
"cohereParametersKTooltip": "Ensures only the top k most likely tokens are considered for generation at each step.",
"cohereParametersMaxTokensLabel": "Max Tokens",
"cohereParametersMaxTokensTooltip": "The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.",
"cohereParametersPLabel": "P",
"cohereParametersPTooltip": "Ensures that only the most likely tokens, with total probability mass of p, are considered for generation at each step. If both k and p are enabled, p acts after k.",
"cohereParametersPresencePenaltyLabel": "Presence Penalty",
"cohereParametersPresencePenaltyTooltip": "Can be used to reduce repetitiveness of generated tokens. Similar to frequency_penalty, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.",

"cohereParametersTemperatureLabel": "Temperature",
"cohereParametersTemperatureTooltip": "A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations.",
"configName": "Config Name",
"continueButtonText": "Continue",
"createConfig": "Create Configuration",
@@ -121,7 +134,7 @@
"promptConfigModelSelectLabel": "Model",
"promptConfigNameInputPlaceholder": "For example: Default Config",
"promptConfigVendorSelectLabel": "Provider",
"promptMessagePlaceholder": "Example of a template: The user favourite albums are: '{favouriteAlbums'}.The user disliked albums are: '{dislikedAlbums'}. Please suggest 5 new albums for the user.",
"promptMessagePlaceholder": "Example of a template: The user favourite albums are: '{favouriteAlbums'}. The user disliked albums are: '{dislikedAlbums'}. Please suggest 5 new albums for the user.",
"promptMessages": "Messages",
"promptTemplate": "Prompt Template",
"requestTokens": "Request Tokens",
@@ -233,6 +233,7 @@ export default function PromptConfigCreateWizard({
<button
data-testid="config-create-wizard-cancel-button"
onClick={() => {
store.resetState();
router.push(
setRouteParams(
Navigation.ApplicationDetail,
@@ -175,8 +175,9 @@ describe('PromptConfigBaseForm', () => {
expect(setConfigName).toHaveBeenCalledWith('New Config');
});

it('should display the model vendor select field as disabled', () => {
it('should display and allow selection of model vendor', () => {
const modelVendor = ModelVendor.OpenAI;
const setVendor = vi.fn();
render(
<PromptConfigBaseForm
validateConfigName={vi.fn()}
@@ -186,15 +187,16 @@
modelVendor={modelVendor}
setConfigName={vi.fn()}
setModelType={vi.fn()}
setVendor={vi.fn()}
setVendor={setVendor}
/>,
);
const select: HTMLInputElement = screen.getByTestId(
'create-prompt-base-form-vendor-select',
);
expect(select).toBeInTheDocument();
expect(select.value).toBe(modelVendor);
expect(select.disabled).toBe(true);
fireEvent.change(select, { target: { value: ModelVendor.Cohere } });
expect(setVendor).toHaveBeenCalledWith(ModelVendor.Cohere);
});

it('should display and allow editing of the model type select field based on the selected model vendor', () => {
@@ -1,13 +1,18 @@
import { useTranslations } from 'next-intl';
import { useMemo } from 'react';
import { useEffect, useMemo } from 'react';

import { EntityNameInput } from '@/components/entity-name-input';
import {
modelTypeToLocaleMap,
modelVendorTypeMap,
UnavailableModelVendor,
} from '@/constants/models';
import { ModelType, ModelVendor } from '@/types';
import {
CohereModelType,
ModelType,
ModelVendor,
OpenAIModelType,
} from '@/types';
import { handleChange } from '@/utils/events';

export function PromptConfigBaseForm({
@@ -33,8 +38,49 @@ }) {
}) {
const t = useTranslations('createConfigWizard');

const modelChoices = useMemo(() => {
return Object.values(modelVendorTypeMap[modelVendor]) as string[];
const activeModelVendor = useMemo(
() =>
Object.entries(ModelVendor)
.filter(([, value]) =>
[ModelVendor.OpenAI, ModelVendor.Cohere].includes(value),
)
.map(([key, value]) => {
return (
<option key={key} value={value}>
{key}
</option>
);
}),
[],
);

const invalidModelVendors = useMemo(
() =>
Object.entries(UnavailableModelVendor).map(([key, value]) => (
<option disabled key={value} value={value}>
{key}
</option>
)),
[],
);

useEffect(() => {
if (
modelVendor === ModelVendor.Cohere &&
!Object.values(CohereModelType).includes(
modelType as CohereModelType,
)
) {
setModelType(CohereModelType.Command);
}
if (
modelVendor === ModelVendor.OpenAI &&
!Object.values(OpenAIModelType).includes(
modelType as OpenAIModelType,
)
) {
setModelType(OpenAIModelType.Gpt35Turbo);
}
}, [modelVendor]);

return (
@@ -73,27 +119,10 @@
className="card-select w-full"
value={modelVendor}
onChange={handleChange(setVendor)}
disabled={true}
data-testid="create-prompt-base-form-vendor-select"
>
{Object.entries(ModelVendor)
.filter(
([, value]) => value === ModelVendor.OpenAI,
)
.map(([key, value]) => {
return (
<option key={key} value={value}>
{key}
</option>
);
})}
{Object.entries(UnavailableModelVendor).map(
([key, value]) => (
<option disabled key={value} value={value}>
{key}
</option>
),
)}
{activeModelVendor}
{invalidModelVendors}
</select>
</div>
<div className="form-control">
@@ -108,20 +137,22 @@
onChange={handleChange(setModelType)}
data-testid="create-prompt-base-form-model-select"
>
{modelChoices.map((modelChoice) => {
return (
<option
key={modelChoice}
value={modelChoice}
>
{
modelTypeToLocaleMap[
modelChoice as ModelType<any>
]
}
</option>
);
})}
{Object.values(modelVendorTypeMap[modelVendor]).map(
(modelChoice: string) => {
return (
<option
key={modelChoice}
value={modelChoice}
>
{
modelTypeToLocaleMap[
modelChoice as ModelType<any>
]
}
</option>
);
},
)}
</select>
</div>
</div>
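
The vendor and model selects above are wired through the handleChange helper imported from '@/utils/events', whose implementation is not shown in this diff. Below is a minimal sketch of what such a curried handler could look like, under the assumption that it simply forwards the select's string value to the provided setter.

// Assumed shape of the handleChange helper from '@/utils/events'; not part of this diff.
import type { ChangeEvent } from 'react';

export function handleChange<T extends string>(
	setter: (value: T) => void,
): (event: ChangeEvent<HTMLSelectElement | HTMLInputElement>) => void {
	// Curried: returns an event handler that extracts the raw value and passes it to the setter.
	return (event) => {
		setter(event.target.value as T);
	};
}

// Usage, as in the component above: onChange={handleChange(setVendor)}
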