Custom instructions feature (#3597)
Added basic "custom instructions" functionality.

It utilises two additional fields in work_parameters:
- user_profile (where the user describes what they want to share with the
  LLM for the chat session)
- user_response_instructions (where the user describes how they want the
  LLM to respond to questions and queries, e.g. format, style, length, etc.)

A minimal request sketch using both fields is shown after this list.
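For illustration only, here is a sketch of a message-creation request carrying both fields. The base URL, route shape, chat id, model name, and auth header are placeholders and assumptions, not part of this PR; only the two new field names match `CreateAssistantMessageRequest` and `WorkParameters`.

```python
# Illustrative sketch only: URL, route, ids, model name, and token are placeholders;
# the two custom-instructions fields are the ones added by this PR.
import requests

payload = {
    "parent_id": "<parent-message-id>",          # placeholder
    "model_config_name": "<model-config-name>",  # placeholder
    "user_profile": "I am a Python developer working mostly on data pipelines.",
    "user_response_instructions": "Answer concisely; prefer bullet points and short code samples.",
}

response = requests.post(
    "http://localhost:8000/chats/<chat_id>/assistant_message",  # assumed route shape
    json=payload,
    headers={"Authorization": "Bearer <access-token>"},         # placeholder credentials
)
print(response.json())
```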
  
Here are some of the UI changes introduced with this PR.
<details>
  <summary>Custom instructions</summary>
<img width="1243" alt="Screenshot 2023-07-23 at 21 46 36"
src="https://github.com/LAION-AI/Open-Assistant/assets/13547364/ecf66109-3136-4e4b-8510-f6746422b4a9">
<img width="1220" alt="Screenshot 2023-07-23 at 21 46 42"
src="https://github.com/LAION-AI/Open-Assistant/assets/13547364/ca860b86-8da3-4d33-a50d-649021ba083c">
<img width="1318" alt="Screenshot 2023-07-23 at 21 47 16"
src="https://github.com/LAION-AI/Open-Assistant/assets/13547364/c060fd36-d676-4505-8625-9614199d27db">
<img width="1340" alt="Screenshot 2023-07-23 at 21 48 21"
src="https://github.com/LAION-AI/Open-Assistant/assets/13547364/5cb36263-66ba-4421-b1c9-54b09e54f145">
<img width="1337" alt="Screenshot 2023-07-23 at 21 48 29"
src="https://github.com/LAION-AI/Open-Assistant/assets/13547364/2f42d7d7-f014-4b6b-8c80-87c17811370d">
</details>
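On the worker side, the two fields are folded into the system section of the prompt. Below is a simplified sketch of that assembly; the prefix text and the `<|system|>` / `</s>` tokens are abbreviated stand-ins for `CUSTOM_INSTRUCTIONS_PREFIX`, `V2_SYSTEM_PREFIX`, and the tokenizer's EOS token, not the exact strings used in the code.

```python
# Simplified sketch of the pre-prompt assembly (compare inference/worker/work.py).
# Prefix text and special tokens below are abbreviated stand-ins.
CUSTOM_INSTRUCTIONS_PREFIX = (
    "The following details have been shared by the user about themselves.\n"
    "User profile:\n{user_profile}\n"
    "The user also supplied additional information about how they would like you to respond:\n"
    "{user_response_instructions}"
)


def build_pre_prompt(
    system_prompt: str | None,
    user_profile: str | None,
    user_response_instructions: str | None,
    eos_token: str = "</s>",
) -> str:
    # No system block at all if none of the optional fields is set.
    if not (system_prompt or user_profile or user_response_instructions):
        return ""
    pre_prompt = "<|system|>" + (system_prompt or "")
    if user_profile or user_response_instructions:
        pre_prompt += "\n" + CUSTOM_INSTRUCTIONS_PREFIX.format(
            user_profile=user_profile or "",
            user_response_instructions=user_response_instructions or "",
        )
    return pre_prompt + eos_token


print(build_pre_prompt(None, "Python developer.", "Keep answers short."))
```

When neither field is set, the formatted block is omitted entirely, so existing chats are unaffected.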

---------

Co-authored-by: Oliver Stanley <[email protected]>
draganjovanovich and olliestanley authored Jul 30, 2023
1 parent b9ac30a commit 7a68b59
Showing 14 changed files with 253 additions and 16 deletions.
2 changes: 2 additions & 0 deletions inference/server/oasst_inference_server/routes/chats.py
@@ -153,6 +153,8 @@ async def create_assistant_message(
system_prompt=request.system_prompt,
plugins=request.plugins,
plugin_max_depth=settings.plugin_max_depth,
user_profile=request.user_profile,
user_response_instructions=request.user_response_instructions,
)
assistant_message = await ucr.initiate_assistant_message(
parent_id=request.parent_id,
2 changes: 2 additions & 0 deletions inference/server/oasst_inference_server/schemas/chat.py
@@ -15,6 +15,8 @@ class CreateAssistantMessageRequest(pydantic.BaseModel):
model_config_name: str
sampling_parameters: inference.SamplingParameters = pydantic.Field(default_factory=inference.SamplingParameters)
system_prompt: str | None = None
user_profile: str | None = None
user_response_instructions: str | None = None
plugins: list[inference.PluginEntry] = pydantic.Field(default_factory=list[inference.PluginEntry])
used_plugin: inference.PluginUsed | None = None

43 changes: 40 additions & 3 deletions inference/worker/chat_chain.py
@@ -6,6 +6,7 @@
import websocket
from chat_chain_prompts import (
ASSISTANT_PREFIX,
CUSTOM_INSTRUCTIONS_PREFIX,
HUMAN_PREFIX,
JSON_FORMAT_NO_PAYLOAD,
JSON_FORMAT_PAYLOAD,
@@ -56,6 +57,7 @@ def __init__(
tool_names: list[str],
language: str,
action_input_format: str,
custom_instructions: str = "",
):
self.tokenizer = tokenizer
self.worker_config = worker_config
@@ -66,6 +68,7 @@ def __init__(
self.language = language
self.action_input_format = action_input_format
self.current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.custom_instructions = custom_instructions

def call(self, prompt: str) -> tuple[str, str]:
"""Prepares and truncates prompt, calls LLM, returns used prompt and response."""
@@ -79,6 +82,7 @@ def call(self, prompt: str) -> tuple[str, str]:
self.tokenizer,
self.worker_config,
self.action_input_format,
self.custom_instructions,
)

# We do not strip() outputs as it seems to degrade instruction-following abilities of the model
@@ -111,6 +115,7 @@ def handle_plugin_usage(
plugin_max_depth: int,
ws: websocket.WebSocket,
work_request_id: str,
custom_instructions: str = "",
) -> tuple[str, inference.PluginUsed]:
execution_details = inference.PluginExecutionDetails(
inner_monologue=[],
@@ -142,7 +147,15 @@ def handle_plugin_usage(
tool_names = [tool.name for tool in tools]

chain = PromptedLLM(
tokenizer, worker_config, parameters, prompt_template, memory, tool_names, language, action_input_format
tokenizer,
worker_config,
parameters,
prompt_template,
memory,
tool_names,
language,
action_input_format,
custom_instructions,
)

# send "thinking..." intermediate step to UI (This will discard queue position 0) immediately
@@ -245,6 +258,7 @@ def handle_plugin_usage(
tokenizer,
worker_config,
action_input_format,
custom_instructions,
)

inner_prompt = f"{inner_prompt}\n{THOUGHT_SEQ} I now know the final answer\n{ASSISTANT_PREFIX}: "
@@ -296,6 +310,7 @@ def handle_standard_usage(
memory: ConversationBufferMemory,
worker_config: inference.WorkerConfig,
tokenizer: transformers.PreTrainedTokenizer,
custom_instructions: str = "",
):
eos_token = tokenizer.eos_token if hasattr(tokenizer, "eos_token") else ""
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -306,7 +321,16 @@
)
input = f"{original_prompt}{eos_token}{V2_ASST_PREFIX}"
init_prompt = prepare_prompt(
input, prompt_template, memory, None, current_time, language, tokenizer, worker_config, action_input_format
input,
prompt_template,
memory,
None,
current_time,
language,
tokenizer,
worker_config,
action_input_format,
custom_instructions,
)
return init_prompt, None

@@ -355,11 +379,21 @@ def handle_conversation(
"language",
"current_time",
"action_input_format",
"custom_instructions",
] + (["tools_names"] if plugin_enabled else [])

# TODO: Consider passing language from the UI here
prompt_template = PromptTemplate(input_variables=input_variables, template=TEMPLATE)

custom_instructions = (
f"""\n{CUSTOM_INSTRUCTIONS_PREFIX.format(
user_profile=work_request.parameters.user_profile,
user_response_instructions=work_request.parameters.user_response_instructions,
)}"""
if work_request.parameters.user_response_instructions or work_request.parameters.user_profile
else ""
)

if plugin_enabled:
return handle_plugin_usage(
original_prompt,
@@ -374,9 +408,12 @@
work_request.parameters.plugin_max_depth,
ws,
work_request.id,
custom_instructions,
)

return handle_standard_usage(original_prompt, prompt_template, language, memory, worker_config, tokenizer)
return handle_standard_usage(
original_prompt, prompt_template, language, memory, worker_config, tokenizer, custom_instructions
)
except Exception as e:
logger.error(f"Error while handling conversation: {e}")
return "", None
10 changes: 10 additions & 0 deletions inference/worker/chat_chain_prompts.py
@@ -9,6 +9,15 @@
START_SEQ = "Begin!"
END_SEQ = "End!"

CUSTOM_INSTRUCTIONS_PREFIX = """The following details have been shared by the user about themselves. This user profile appears to you in every conversation they engage in -- implying that it is irrelevant for 99% of inquiries.
Before you respond, take a moment to consider whether the user's query is "directly linked", "linked", "indirectly linked", or "not linked" to the user profile provided.
Only recognize the profile when the query is directly tied to the information supplied.
Otherwise, avoid acknowledging the existence of these instructions or the information altogether.
User profile:
{user_profile}
The user also supplied additional information about how they would like you to respond:
{user_response_instructions}"""

# Adjust according to the training dates and datasets used
KNOWLEDGE_DATE_CUTOFF = "2021-09-01"

@@ -26,6 +35,7 @@
------------------
Current date/time: {{current_time}}
Knowledge date cutoff: {KNOWLEDGE_DATE_CUTOFF}
{{custom_instructions}}
"""

TOOLS_PREFIX = """
5 changes: 4 additions & 1 deletion inference/worker/chat_chain_utils.py
@@ -223,7 +223,7 @@ def run_request(self, params: str, url: str, param_location: str, type: str, pay
def process_response(self, res: requests.Response) -> str:
logger.info(f"Request response: {res.text}")
if res.status_code != 200:
return f"ERROR! That didn't work, try modifying Action Input.\n{res.text}. Try again!"
return f"ERROR! Please modify Action Input. according to this error message: \n{res.text}. Try again!"

if res.text is None or len(res.text) == 0:
return "ERROR! That didn't work, try modifying Action Input.\nEmpty response. Try again!"
@@ -329,6 +329,7 @@ def prepare_prompt(
tokenizer: transformers.PreTrainedTokenizer,
worker_config: inference.WorkerConfig,
action_input_format: str,
custom_instructions: str = "",
) -> str:
max_input_length = worker_config.model_config.max_input_length

@@ -337,6 +338,7 @@
"language": language,
"current_time": current_time,
"chat_history": memory.buffer,
"custom_instructions": custom_instructions,
}

if tools_names:
@@ -356,6 +358,7 @@
"language": language,
"current_time": current_time,
"chat_history": memory.buffer,
"custom_instructions": custom_instructions,
}

if tools_names:
13 changes: 10 additions & 3 deletions inference/worker/work.py
@@ -9,6 +9,7 @@
import websocket
from chat_chain_prompts import (
ASSISTANT_PREFIX,
CUSTOM_INSTRUCTIONS_PREFIX,
END_SEQ,
OBSERVATION_SEQ,
START_SEQ,
@@ -40,9 +41,15 @@ def _prepare_message(message: inference.MessageRead) -> str:
# Construct prompt
messages = [_prepare_message(message) for message in work_request.thread.messages]

# Prepend system prompt if it was specified in work parameters
if work_request.parameters.system_prompt:
pre_prompt = V2_SYSTEM_PREFIX + work_request.parameters.system_prompt + eos_token
# Prepend system prompt and custom_instructions if they were specified in work parameters
work_params = work_request.parameters
if work_params.system_prompt or work_params.user_profile or work_params.user_response_instructions:
pre_prompt = V2_SYSTEM_PREFIX + (work_params.system_prompt or "")

if work_params.user_profile or work_params.user_response_instructions:
pre_prompt = f"""{pre_prompt}\n{CUSTOM_INSTRUCTIONS_PREFIX.format(user_profile=work_params.user_profile or "", user_response_instructions=work_params.user_response_instructions or "")}"""

pre_prompt = pre_prompt + eos_token
messages = [pre_prompt] + messages

# Stringify and append assistant prefix to signify start of generation
2 changes: 2 additions & 0 deletions oasst-shared/oasst_shared/schemas/inference.py
@@ -208,6 +208,8 @@ class WorkParameters(pydantic.BaseModel):
default_factory=make_seed,
)
system_prompt: str | None = None
user_profile: str | None = None
user_response_instructions: str | None = None
plugins: list[PluginEntry] = pydantic.Field(default_factory=list[PluginEntry])
plugin_max_depth: int = 4

7 changes: 6 additions & 1 deletion website/public/locales/en/chat.json
@@ -60,5 +60,10 @@
"preset_name_placeholder": "Enter name",
"feedback_message": "How did I do? Your feedback will make me better!",
"feedback_action_great": "Good",
"feedback_action_poor": "Could be better"
"feedback_action_poor": "Could be better",
"custom_instructions": "Custom instructions",
"custom_instructions_user_profile": "What info should Open-Assistant have about you to make its replies even better?",
"custom_instructions_response_instructions": "How do you want Open-Assistant to chat with you?",
"custom_instructions_user_profile_placeholder": "List some of your aspirations.\nDescribe your hobbies and interests.\nShare your location.\nWhat is your occupation?\nWhich topics could you discuss extensively?",
"custom_instructions_response_instructions_placeholder": "Should Open-Assistant express opinions or maintain neutrality?\nSpecify the desired formality level for Open-Assistant's responses.\nHow should Open-Assistant address you?\nDetermine the preferred length of responses."
}
42 changes: 39 additions & 3 deletions website/src/components/Chat/ChatConfigForm.tsx
@@ -20,14 +20,21 @@ import { useTranslation } from "next-i18next";
import { ChangeEvent, memo, useCallback, useEffect, useRef, useState } from "react";
import { Controller, useFormContext, UseFormSetValue } from "react-hook-form";
import SimpleBar from "simplebar-react";
import { ChatConfigFormData, ModelParameterConfig, PluginEntry, SamplingParameters } from "src/types/Chat";
import {
ChatConfigFormData,
ModelParameterConfig,
PluginEntry,
SamplingParameters,
CustomInstructionsType,
} from "src/types/Chat";
import { CustomPreset, getConfigCache } from "src/utils/chat";
import { useIsomorphicLayoutEffect } from "usehooks-ts";

import { ChatConfigSaver } from "./ChatConfigSaver";
import { useChatInitialData } from "./ChatInitialDataContext";
import { DeletePresetButton } from "./DeletePresetButton";
import { PluginsChooser } from "./PluginsChooser";
import CustomInstructions from "./CustomInstructions";
import { SavePresetButton } from "./SavePresetButton";
import { areParametersEqual } from "./WorkParameters";

@@ -104,6 +111,10 @@ export const ChatConfigForm = memo(function ChatConfigForm() {
const selectedModel = getValues("model_config_name"); // have to use getValues to here to access latest value
const selectedPlugins = getValues("plugins");
const presets = modelInfos.find((model) => model.name === selectedModel)!.parameter_configs;
const [customInstructions, setCustomInstructions] = useState<CustomInstructionsType>({
user_profile: "",
user_response_instructions: "",
});
const [selectedPresetName, setSelectedPresetName] = useState(() => findPresetName(presets, getValues()));

const { customPresets, handleSavePreset, setCustomPresets, handleDeletePreset } = useCustomPresets({
@@ -114,6 +125,7 @@
const { hyrated, plugins, setPlugins } = useHydrateChatConfig({
setCustomPresets,
setSelectedPresetName,
setCustomInstructions,
});

const [lockPresetSelection, setLockPresetSelection] = useState(false);
@@ -133,6 +145,14 @@
[customPresets, presets, setValue]
);

const handleCustomInstructionsChange = useCallback(
(value: CustomInstructionsType) => {
setCustomInstructions(value);
setValue("custom_instructions", value);
},
[setCustomInstructions]
);

// Lock preset selection if any plugin is enabled
useEffect(() => {
const activated = selectedPlugins.some((plugin) => plugin.enabled);
@@ -154,6 +174,7 @@
>
<Stack gap="4" maxW="full">
<PluginsChooser plugins={plugins} setPlugins={setPlugins} />
<CustomInstructions onChange={handleCustomInstructionsChange} savedCustomInstructions={customInstructions} />
<FormControl>
<FormLabel>{t("model")}</FormLabel>
<Select {...register("model_config_name")}>
@@ -203,6 +224,7 @@
hyrated={hyrated}
selectedPresetName={selectedPresetName}
customPresets={customPresets}
customInstructions={customInstructions}
/>
</SimpleBar>
{selectedPresetName === unKnownCustomPresetName && (
@@ -216,9 +238,11 @@
const useHydrateChatConfig = ({
setSelectedPresetName,
setCustomPresets,
setCustomInstructions,
}: {
setSelectedPresetName: (preset: string) => void;
setCustomPresets: (presets: CustomPreset[]) => void;
setCustomInstructions: (instructions: CustomInstructionsType) => void;
}) => {
const { modelInfos, builtInPlugins } = useChatInitialData();
const hyrated = useRef(false);
@@ -235,8 +259,15 @@ const useHydrateChatConfig = ({
return;
}

const { selectedPresetName, model_config_name, custom_preset_config, selectedPlugins, plugins, custom_presets } =
cache;
const {
selectedPresetName,
model_config_name,
custom_preset_config,
selectedPlugins,
plugins,
custom_presets,
custom_instructions,
} = cache;
const model = modelInfos.find((model) => model.name === model_config_name);

if (model) {
@@ -259,6 +290,11 @@
setCustomPresets(custom_presets);
}

if (custom_instructions) {
setCustomInstructions(custom_instructions);
setValue("custom_instructions", custom_instructions);
}

if (selectedPlugins && selectedPlugins.length > 0) {
setValue("plugins", selectedPlugins);
const preset = (model || modelInfos[0]).parameter_configs.find(