Skip to content

Commit

Permalink
feat: updated prompt content display
Browse files Browse the repository at this point in the history
  • Loading branch information
Goldziher committed Jan 9, 2024
1 parent 87528ba commit d12d8dd
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ function shouldAllowContinue(store: PromptConfigWizardStore) {
return false;
}
if (store.modelVendor === ModelVendor.Cohere) {
const [message] = store.messages as CoherePromptMessage[];
return message.message.trim().length;
const messages = store.messages as CoherePromptMessage[];
return !!messages[0]?.message.trim().length;
}
return store.messages.every(
(message) => (message as OpenAIPromptMessage).content.trim().length,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ export function PromptConfigTestingForm<T extends ModelVendor>({
<PromptContentDisplay
modelVendor={modelVendor}
messages={messages}
templateVariables={templateVariables}
/>
</>
{expectedVariables.length > 0 && (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import { waitFor } from '@testing-library/react';
import {
CohereMessageFactory,
OpenAIPromptMessageFactory,
} from 'tests/factories';
import { getLocaleNamespace, render, screen } from 'tests/test-utils';
import { render, screen } from 'tests/test-utils';
import { expect } from 'vitest';

import { PromptContentDisplay } from '@/components/prompt-display-components/prompt-content-display';
import { ModelVendor } from '@/types';

describe('PromptContentDisplay', () => {
const namespace = getLocaleNamespace('createConfigWizard');

it('should display an array of openAI messages correctly', () => {
const messages = OpenAIPromptMessageFactory.batchSync(10);

Expand All @@ -22,7 +21,7 @@ describe('PromptContentDisplay', () => {
);

const messageContainer = screen.getByTestId(
'prompt-content-display-messages',
'prompt-content-display-container',
);
expect(messageContainer).toBeInTheDocument();
expect(messageContainer.children.length).toBe(10);
Expand All @@ -33,36 +32,86 @@ describe('PromptContentDisplay', () => {
}
});

it('should display an array of cohere messages correctly', () => {
const messages = CohereMessageFactory.batchSync(5);
it('should display a cohere message correctly', () => {
const message = CohereMessageFactory.buildSync();
render(
<PromptContentDisplay
messages={messages}
messages={[message]}
modelVendor={ModelVendor.Cohere}
/>,
);
const messageContainer = screen.getByTestId(
'prompt-content-display-messages',
'prompt-content-display-container',
);
expect(messageContainer).toBeInTheDocument();
expect(messageContainer.children.length).toBe(5);

for (const [i, message] of messages.entries()) {
const content = `[${i}]: ${message.message}`;
expect(content).toBe(messageContainer.children[i].textContent);
}
expect(messageContainer).toHaveTextContent(message.message);
});

it('should have the correct title', () => {
const messages = CohereMessageFactory.batchSync(5);
it('should parse template variables correctly for OpenAI messages', async () => {
const messages = OpenAIPromptMessageFactory.batchSync(3);
messages[0].content = 'Hello {var1}';
messages[1].content = 'Hello {var2}';
messages[2].content = 'Hello {var3}';

const templateVariables = {
var1: 'world',
var2: 'there',
};

render(
<PromptContentDisplay
messages={messages}
modelVendor={ModelVendor.OpenAI}
templateVariables={templateVariables}
/>,
);
const messageContainer = screen.getByTestId(
'prompt-content-display-container',
);
expect(messageContainer).toBeInTheDocument();

const variables = screen.queryAllByTestId('template-variable');
await waitFor(() => {
expect(variables.length).toBe(3);
});

expect(variables[0]).toHaveTextContent('world');
expect(variables[1]).toHaveTextContent('there');
expect(variables[2]).toHaveTextContent('{var3}');
});

it('should parse template variables correctly for OpenAI messages', async () => {
const message = CohereMessageFactory.buildSync({
message: `
Hello {var1}
Hello {var2}
Hello {var3}
`,
});
const templateVariables = {
var1: 'world',
var2: 'there',
};

render(
<PromptContentDisplay
messages={[message]}
modelVendor={ModelVendor.Cohere}
templateVariables={templateVariables}
/>,
);
const title = screen.getByTestId('prompt-content-display-title');
expect(title).toBeInTheDocument();
expect(title.textContent).toBe(namespace.promptTemplate);
const messageContainer = screen.getByTestId(
'prompt-content-display-container',
);
expect(messageContainer).toBeInTheDocument();

const variables = screen.queryAllByTestId('template-variable');
await waitFor(() => {
expect(variables.length).toBe(3);
});

expect(variables[0]).toHaveTextContent('world');
expect(variables[1]).toHaveTextContent('there');
expect(variables[2]).toHaveTextContent('{var3}');
});
});
Original file line number Diff line number Diff line change
@@ -1,80 +1,114 @@
import { ReactNode, useEffect, useState } from 'react';
import reactStringReplace from 'react-string-replace';

import { openAIRoleColorMap } from '@/constants/models';
import {
ModelVendor,
OpenAIPromptMessageRole,
ProviderMessageType,
} from '@/types';
import { curlyBracketsRe } from '@/utils/models';

const openAIRoleElementMapper: Record<
OpenAIPromptMessageRole,
React.FC<{ message: string }>
React.FC<{ message: string; templateVariables?: Record<string, string> }>
> = {
[OpenAIPromptMessageRole.Assistant]: ({ message }) => (
[OpenAIPromptMessageRole.Assistant]: ({ message, templateVariables }) => (
<span>
[
<span className={`text-${openAIRoleColorMap.assistant}`}>
{`${OpenAIPromptMessageRole.Assistant} message`}
</span>
]: <span className="text-base-content">{message}</span>
]:{' '}
<span className="text-base-content">
{parseTemplateVariables(message, templateVariables)}
</span>
</span>
),

[OpenAIPromptMessageRole.System]: ({ message }) => (
[OpenAIPromptMessageRole.System]: ({ message, templateVariables }) => (
<span>
[
<span
className={`text-${openAIRoleColorMap.system}`}
>{`${OpenAIPromptMessageRole.System} message`}</span>
]: <span className="text-base-content">{message}</span>
]:{' '}
<span className="text-base-content">
{parseTemplateVariables(message, templateVariables)}
</span>
</span>
),
[OpenAIPromptMessageRole.User]: ({ message }) => (
[OpenAIPromptMessageRole.User]: ({ message, templateVariables }) => (
<span>
[
<span
className={`text-${openAIRoleColorMap.user}`}
>{`${OpenAIPromptMessageRole.User} message`}</span>
]: <span className="text-base-content">{message}</span>
]:{' '}
<span className="text-base-content">
{parseTemplateVariables(message, templateVariables)}
</span>
</span>
),
};

const contentMapper: Record<
ModelVendor,
(message: any, index: number) => React.ReactNode
(
templateVariables?: Record<string, string>,
) => (message: any, index: number) => React.ReactNode
> = {
[ModelVendor.Cohere]: (
m: ProviderMessageType<ModelVendor.Cohere>,
i: number,
) => (
<span>
[<span className="text-accent">{i}</span>]:{' '}
<span className="text-info">{m.message}</span>
[ModelVendor.Cohere]:
(templateVariables?: Record<string, string>) =>
(m: ProviderMessageType<ModelVendor.Cohere>) => (
<span className="text-info">
{parseTemplateVariables(m.message, templateVariables)}
</span>
),
[ModelVendor.OpenAI]:
(templateVariables?: Record<string, string>) =>
(m: ProviderMessageType<ModelVendor.OpenAI>) =>
openAIRoleElementMapper[m.role]({
message: m.content,
templateVariables,
}),
};

const parseTemplateVariables = (
content: string,
templateVariables?: Record<string, string>,
) => {
return reactStringReplace(content, curlyBracketsRe, (match, i) => (
<span key={i} className="text-primary" data-testid="template-variable">
{templateVariables?.[match] ?? `{${match}}`}
</span>
),
[ModelVendor.OpenAI]: (m: ProviderMessageType<ModelVendor.OpenAI>) =>
openAIRoleElementMapper[m.role]({
message: m.content,
}),
));
};

export function PromptContentDisplay<T extends ModelVendor>({
messages,
modelVendor,
className = 'rounded-dark-card pt-2 pb-2',
templateVariables,
}: {
className?: string;
messages: ProviderMessageType<T>[];
modelVendor: ModelVendor;
templateVariables?: Record<string, string>;
}) {
const messageContent = messages.map(contentMapper[modelVendor]);
const [parsedMessages, setParsedMessages] = useState<ReactNode[]>([]);

useEffect(() => {
const mapper = contentMapper[modelVendor](templateVariables);
setParsedMessages(messages.map(mapper));
}, [messages, modelVendor, templateVariables]);

return (
<div
className={className}
data-testid="prompt-content-display-container"
>
{messageContent.map((msg, i) => (
{parsedMessages.map((msg, i) => (
<p
className="text-base-content"
data-testid="message-content-paragraph"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export function PromptTestInputs({
{expectedVariables.map((variable) => (
<div key={variable} className="form-control">
<textarea
className="card-textarea textarea-info bg-neutral placeholder-info"
className="card-textarea textarea-info bg-neutral placeholder-info min-w-full"
data-testid={`input-variable-input-${variable}`}
value={templateVariables[variable] ?? ''}
placeholder={`{${variable}}`}
Expand Down

0 comments on commit d12d8dd

Please sign in to comment.