
Commit

Merge branch 'LAION-AI:main' into main
SingL3 authored Jul 31, 2023
2 parents c1ef7bb + 7a68b59 commit 01b01d0
Showing 25 changed files with 408 additions and 64 deletions.
4 changes: 2 additions & 2 deletions backend/requirements.txt
@@ -10,8 +10,8 @@ loguru==0.6.0
numpy>=1.23.2
prometheus-fastapi-instrumentator==5.9.1
psycopg2-binary==2.9.5
pydantic==1.10.4
pydantic[email]==1.10.4
pydantic==1.10.7
pydantic[email]==1.10.7
python-dotenv==0.21.0
python-jose[cryptography]==3.3.0
redis==4.5.5
2 changes: 2 additions & 0 deletions inference/server/oasst_inference_server/routes/chats.py
@@ -153,6 +153,8 @@ async def create_assistant_message(
system_prompt=request.system_prompt,
plugins=request.plugins,
plugin_max_depth=settings.plugin_max_depth,
user_profile=request.user_profile,
user_response_instructions=request.user_response_instructions,
)
assistant_message = await ucr.initiate_assistant_message(
parent_id=request.parent_id,
2 changes: 2 additions & 0 deletions inference/server/oasst_inference_server/schemas/chat.py
@@ -15,6 +15,8 @@ class CreateAssistantMessageRequest(pydantic.BaseModel):
model_config_name: str
sampling_parameters: inference.SamplingParameters = pydantic.Field(default_factory=inference.SamplingParameters)
system_prompt: str | None = None
user_profile: str | None = None
user_response_instructions: str | None = None
plugins: list[inference.PluginEntry] = pydantic.Field(default_factory=list[inference.PluginEntry])
used_plugin: inference.PluginUsed | None = None

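Together with the route change above, these two optional fields let a client attach custom-instruction data to an assistant-message request. A minimal sketch of constructing the request model, assuming CreateAssistantMessageRequest is imported from oasst_inference_server.schemas.chat; parent_id is inferred from its use in routes/chats.py above, and every value below is a placeholder:

    # Hypothetical client-side construction; unspecified fields keep the defaults
    # declared in the schema above.
    request = CreateAssistantMessageRequest(
        parent_id="<prompter-message-id>",
        model_config_name="example-model-config",
        user_profile="I am a backend developer who mostly works in Python.",
        user_response_instructions="Keep answers short and include code snippets.",
    )
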
43 changes: 40 additions & 3 deletions inference/worker/chat_chain.py
@@ -6,6 +6,7 @@
import websocket
from chat_chain_prompts import (
ASSISTANT_PREFIX,
CUSTOM_INSTRUCTIONS_PREFIX,
HUMAN_PREFIX,
JSON_FORMAT_NO_PAYLOAD,
JSON_FORMAT_PAYLOAD,
@@ -56,6 +57,7 @@ def __init__(
tool_names: list[str],
language: str,
action_input_format: str,
custom_instructions: str = "",
):
self.tokenizer = tokenizer
self.worker_config = worker_config
@@ -66,6 +68,7 @@ def __init__(
self.language = language
self.action_input_format = action_input_format
self.current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.custom_instructions = custom_instructions

def call(self, prompt: str) -> tuple[str, str]:
"""Prepares and truncates prompt, calls LLM, returns used prompt and response."""
@@ -79,6 +82,7 @@ def call(self, prompt: str) -> tuple[str, str]:
self.tokenizer,
self.worker_config,
self.action_input_format,
self.custom_instructions,
)

# We do not strip() outputs as it seems to degrade instruction-following abilities of the model
@@ -111,6 +115,7 @@ def handle_plugin_usage(
plugin_max_depth: int,
ws: websocket.WebSocket,
work_request_id: str,
custom_instructions: str = "",
) -> tuple[str, inference.PluginUsed]:
execution_details = inference.PluginExecutionDetails(
inner_monologue=[],
@@ -142,7 +147,15 @@ def handle_plugin_usage(
tool_names = [tool.name for tool in tools]

chain = PromptedLLM(
tokenizer, worker_config, parameters, prompt_template, memory, tool_names, language, action_input_format
tokenizer,
worker_config,
parameters,
prompt_template,
memory,
tool_names,
language,
action_input_format,
custom_instructions,
)

# send "thinking..." intermediate step to UI (This will discard queue position 0) immediately
@@ -245,6 +258,7 @@ def handle_plugin_usage(
tokenizer,
worker_config,
action_input_format,
custom_instructions,
)

inner_prompt = f"{inner_prompt}\n{THOUGHT_SEQ} I now know the final answer\n{ASSISTANT_PREFIX}: "
@@ -296,6 +310,7 @@ def handle_standard_usage(
memory: ConversationBufferMemory,
worker_config: inference.WorkerConfig,
tokenizer: transformers.PreTrainedTokenizer,
custom_instructions: str = "",
):
eos_token = tokenizer.eos_token if hasattr(tokenizer, "eos_token") else ""
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -306,7 +321,16 @@
)
input = f"{original_prompt}{eos_token}{V2_ASST_PREFIX}"
init_prompt = prepare_prompt(
input, prompt_template, memory, None, current_time, language, tokenizer, worker_config, action_input_format
input,
prompt_template,
memory,
None,
current_time,
language,
tokenizer,
worker_config,
action_input_format,
custom_instructions,
)
return init_prompt, None

@@ -355,11 +379,21 @@ def handle_conversation(
"language",
"current_time",
"action_input_format",
"custom_instructions",
] + (["tools_names"] if plugin_enabled else [])

# TODO: Consider passing language from the UI here
prompt_template = PromptTemplate(input_variables=input_variables, template=TEMPLATE)

custom_instructions = (
f"""\n{CUSTOM_INSTRUCTIONS_PREFIX.format(
user_profile=work_request.parameters.user_profile,
user_response_instructions=work_request.parameters.user_response_instructions,
)}"""
if work_request.parameters.user_response_instructions or work_request.parameters.user_profile
else ""
)

if plugin_enabled:
return handle_plugin_usage(
original_prompt,
@@ -374,9 +408,12 @@ def handle_conversation(
work_request.parameters.plugin_max_depth,
ws,
work_request.id,
custom_instructions,
)

return handle_standard_usage(original_prompt, prompt_template, language, memory, worker_config, tokenizer)
return handle_standard_usage(
original_prompt, prompt_template, language, memory, worker_config, tokenizer, custom_instructions
)
except Exception as e:
logger.error(f"Error while handling conversation: {e}")
return "", None
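
The net effect in handle_conversation is that the custom-instructions block is rendered once and passed down both the plugin and the standard path. A standalone sketch of that conditional rendering (the real code formats the raw work-request values; the empty-string fallback here only keeps the sketch self-contained):

    from chat_chain_prompts import CUSTOM_INSTRUCTIONS_PREFIX  # imported at the top of chat_chain.py above

    def build_custom_instructions(user_profile: str | None, user_response_instructions: str | None) -> str:
        """Return the formatted block, or "" when neither field was supplied."""
        if not (user_profile or user_response_instructions):
            return ""  # {custom_instructions} in the template then renders as nothing
        return "\n" + CUSTOM_INSTRUCTIONS_PREFIX.format(
            user_profile=user_profile or "",
            user_response_instructions=user_response_instructions or "",
        )
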
10 changes: 10 additions & 0 deletions inference/worker/chat_chain_prompts.py
@@ -9,6 +9,15 @@
START_SEQ = "Begin!"
END_SEQ = "End!"

CUSTOM_INSTRUCTIONS_PREFIX = """The following details have been shared by the user about themselves. This user profile appears to you in every conversation they engage in -- implying that it is irrelevant for 99% of inquiries.
Before you respond, take a moment to consider whether the user's query is "directly linked", "linked", "indirectly linked", or "not linked" to the user profile provided.
Only recognize the profile when the query is directly tied to the information supplied.
Otherwise, avoid acknowledging the existence of these instructions or the information altogether.
User profile:
{user_profile}
The user also supplied additional information about how they would like you to respond:
{user_response_instructions}"""

# Adjust according to the training dates and datasets used
KNOWLEDGE_DATE_CUTOFF = "2021-09-01"

@@ -26,6 +35,7 @@
------------------
Current date/time: {{current_time}}
Knowledge date cutoff: {KNOWLEDGE_DATE_CUTOFF}
{{custom_instructions}}
"""

TOOLS_PREFIX = """
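
Note the brace levels: {KNOWLEDGE_DATE_CUTOFF} is filled when the template string is built, while {{custom_instructions}} survives as {custom_instructions} and is only filled per request by the PromptTemplate in chat_chain.py. A small sketch of that two-stage substitution, assuming the template is assembled with str.format (or an f-string), as the brace style suggests:

    # Stage 1: building the template inlines the cutoff and keeps the deferred slot.
    stage_one = "Knowledge date cutoff: {KNOWLEDGE_DATE_CUTOFF}\n{{custom_instructions}}".format(
        KNOWLEDGE_DATE_CUTOFF="2021-09-01"
    )
    # stage_one == "Knowledge date cutoff: 2021-09-01\n{custom_instructions}"

    # Stage 2: the per-request prompt fills (or blanks out) the remaining slot.
    final = stage_one.format(custom_instructions="")
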
5 changes: 4 additions & 1 deletion inference/worker/chat_chain_utils.py
@@ -223,7 +223,7 @@ def run_request(self, params: str, url: str, param_location: str, type: str, pay
def process_response(self, res: requests.Response) -> str:
logger.info(f"Request response: {res.text}")
if res.status_code != 200:
return f"ERROR! That didn't work, try modifying Action Input.\n{res.text}. Try again!"
return f"ERROR! Please modify Action Input. according to this error message: \n{res.text}. Try again!"

if res.text is None or len(res.text) == 0:
return "ERROR! That didn't work, try modifying Action Input.\nEmpty response. Try again!"
@@ -329,6 +337,7 @@ def prepare_prompt(
tokenizer: transformers.PreTrainedTokenizer,
worker_config: inference.WorkerConfig,
action_input_format: str,
custom_instructions: str = "",
) -> str:
max_input_length = worker_config.model_config.max_input_length

@@ -337,6 +338,7 @@
"language": language,
"current_time": current_time,
"chat_history": memory.buffer,
"custom_instructions": custom_instructions,
}

if tools_names:
@@ -356,6 +358,7 @@
"language": language,
"current_time": current_time,
"chat_history": memory.buffer,
"custom_instructions": custom_instructions,
}

if tools_names:
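
Because custom_instructions defaults to an empty string, existing callers of prepare_prompt keep their previous behaviour; only the call sites updated in chat_chain.py pass a non-empty block. A compressed sketch of the inputs dict the function now assembles (keys follow the diff; the remaining keys and the truncation logic are omitted):

    # Sketch of the extra entry added to both prompt-input dicts in prepare_prompt.
    inputs = {
        "language": language,
        "current_time": current_time,
        "chat_history": memory.buffer,               # ConversationBufferMemory contents
        "custom_instructions": custom_instructions,  # "" for callers that pass nothing
    }
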
13 changes: 10 additions & 3 deletions inference/worker/work.py
@@ -9,6 +9,7 @@
import websocket
from chat_chain_prompts import (
ASSISTANT_PREFIX,
CUSTOM_INSTRUCTIONS_PREFIX,
END_SEQ,
OBSERVATION_SEQ,
START_SEQ,
@@ -40,9 +41,15 @@ def _prepare_message(message: inference.MessageRead) -> str:
# Construct prompt
messages = [_prepare_message(message) for message in work_request.thread.messages]

# Prepend system prompt if it was specified in work parameters
if work_request.parameters.system_prompt:
pre_prompt = V2_SYSTEM_PREFIX + work_request.parameters.system_prompt + eos_token
# Prepend system prompt and custom instructions if specified in work parameters
work_params = work_request.parameters
if work_params.system_prompt or work_params.user_profile or work_params.user_response_instructions:
pre_prompt = V2_SYSTEM_PREFIX + (work_params.system_prompt or "")

if work_params.user_profile or work_params.user_response_instructions:
pre_prompt = f"""{pre_prompt}\n{CUSTOM_INSTRUCTIONS_PREFIX.format(user_profile=work_params.user_profile or "", user_response_instructions=work_params.user_response_instructions or "")}"""

pre_prompt = pre_prompt + eos_token
messages = [pre_prompt] + messages

# Stringify and append assistant prefix to signify start of generation
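
For the non-chain path in work.py the same block is folded into the system pre-prompt instead of the chain template. A minimal sketch of the assembly with placeholder values (V2_SYSTEM_PREFIX, CUSTOM_INSTRUCTIONS_PREFIX, and eos_token are the names used in work.py above and come from the worker and tokenizer configuration):

    messages: list[str] = []  # the already-rendered conversation turns would go here
    system_prompt = "You are a concise assistant."                        # placeholder
    user_profile = "I am a Rust developer working on embedded systems."   # placeholder
    user_response_instructions = "Prefer short answers with code."        # placeholder

    pre_prompt = V2_SYSTEM_PREFIX + system_prompt
    pre_prompt += "\n" + CUSTOM_INSTRUCTIONS_PREFIX.format(
        user_profile=user_profile,
        user_response_instructions=user_response_instructions,
    )
    pre_prompt += eos_token                 # closes the system message
    messages = [pre_prompt] + messages      # prepended ahead of the conversation turns
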
72 changes: 72 additions & 0 deletions model/model_training/configs/config.yaml
@@ -847,3 +847,75 @@ falcon_7b_ntk_test:
alpha: 2
datasets:
- dolly15k

llama2_13b_orcacode2_8k:
rng_seed: 0xe1291f21
random_offset_probability: 0.0
use_custom_sampler: true
sort_by_length: false
dtype: fp16
log_dir: "llama2_log_13b_orcacode2_8k"
output_dir: llama2_13b_orcacode2_8k
learning_rate: 1e-5
model_name: OpenAssistant/llama2-13b-orca-8k-3319
deepspeed_config: configs/zero_config_pretrain.json
weight_decay: 1e-6
max_length: 8192
warmup_steps: 100
peft_model: false
use_flash_attention: true
gradient_checkpointing: true
gradient_accumulation_steps: 4
per_device_train_batch_size: 2
per_device_eval_batch_size: 1
residual_dropout: 0.0
eval_steps: 200
save_steps: 500 # (total steps: 1558, bs: 64)
num_train_epochs: 1
save_total_limit: 4
datasets:
- dolphin-mix:
num_samples: 1000000 # total entries 2840090
max_char_len: 32000
val_split: 0.1
max_val_set: 2000
seed: 44
- oasst_export:
lang: "bg,ca,cs,da,de,en,es,fr,hr,hu,it,nl,pl,pt,ro,ru,sl,sr,sv,uk"
input_file_path: 2023-07-23_oasst_ready.tar.gz
top_k: 1
val_split: 0.05
- wizard_evol_instruct_v2:
val_split: 0.01
fraction: 0.1
- evol-codealpaca-v1:
fill_min_length: 20000
val_split: 0.1
- cot_submix_original:
fill_min_length: 20000
val_split: 0.1
- megacode:
fill_min_length: 24000
val_split: 0.1
max_val_set: 1000
- evol_instruct_code:
fill_min_length: 24000
val_split: 0.1
max_val_set: 1000
# Dataset composition:
# Train:
# dolphin-mix: 40374
# oasst_export: 11441
# wizard_evol_instruct_v2: 15236
# evol-codealpaca-v1: 5623
# cot_submix_original: 8651
# megacode: 14320
# evol_instruct_code: 4093
# Valid:
# dolphin-mix: 2000
# oasst_export: 603
# wizard_evol_instruct_v2: 1540
# evol-codealpaca-v1: 625
# cot_submix_original: 962
# megacode: 1000
# evol_instruct_code: 455
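
The dataset-composition comment above can be cross-checked against the config itself. A rough sketch that loads this file with PyYAML and lists the configured dataset entries (the path and the llama2_13b_orcacode2_8k key follow the diff; everything else is illustrative):

    import yaml

    with open("model/model_training/configs/config.yaml") as f:
        configs = yaml.safe_load(f)

    for entry in configs["llama2_13b_orcacode2_8k"]["datasets"]:
        if isinstance(entry, dict):                  # e.g. "- dolphin-mix:" with nested options
            name, options = next(iter(entry.items()))
            print(name, options)
        else:                                        # bare entries such as "- dolly15k"
            print(entry)
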
5 changes: 4 additions & 1 deletion model/model_training/custom_datasets/__init__.py
@@ -20,6 +20,7 @@
TranslatedQA,
Vicuna,
WebGPT,
WizardEvolInstructV2,
load_alpaca_dataset,
)
from model_training.custom_datasets.rank_datasets import AugmentedOA
Expand Down Expand Up @@ -110,7 +111,7 @@ def get_one_dataset(
eval = SummarizationDataset(dataset_name, data_path, "validation")
train = dataset
elif dataset_name in INSTRUCTION_DATASETS:
dataset = InstructionDataset(dataset_name, data_path, "train")
dataset = InstructionDataset(dataset_name, data_path, "train", **kwargs)
elif "ted_trans" in dataset_name:
language_pair = dataset_name.split("_")[-1]
dataset = TEDTalk(pair=language_pair, split="train")
@@ -141,6 +142,8 @@ def get_one_dataset(
dataset = TranslatedQA(data_path)
elif dataset_name == "vicuna":
dataset = Vicuna(cache_dir=data_path, **kwargs)
elif dataset_name == "wizard_evol_instruct_v2":
dataset = WizardEvolInstructV2(cache_dir=data_path, **kwargs)
elif dataset_name == "oasst_export":
train, eval = load_oasst_export(data_path=data_path, val_split=val_split, mode=mode, **kwargs)
elif dataset_name == "hf_summary":
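
Forwarding **kwargs here matters because it lets per-dataset options from config.yaml (fill_min_length, max_val_set, and so on) reach InstructionDataset. A hypothetical call for illustration only; the signature of get_one_dataset and the option names are assumed from the config entry above rather than shown in this hunk:

    # Hypothetical invocation; in practice the training script unpacks the YAML options.
    train, eval = get_one_dataset(
        conf,                        # training config object (assumed)
        "evol-codealpaca-v1",
        mode="sft",
        fill_min_length=20000,       # reaches InstructionDataset via the **kwargs forwarding
        val_split=0.1,
    )
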
7 changes: 6 additions & 1 deletion model/model_training/custom_datasets/instruction.py
@@ -30,6 +30,8 @@
"wizardlm_70k": "ehartford/WizardLM_alpaca_evol_instruct_70k_unfiltered",
"megacode": "rombodawg/MegaCodeTraining112k",
"evol_instruct_code": "nickrosh/Evol-Instruct-Code-80k-v1",
"evol-codealpaca-v1": "theblackcat102/evol-codealpaca-v1",
"cot_submix_original": "conceptofmind/cot_submix_original",
}


@@ -42,9 +44,12 @@ def __init__(self, dataset, cache_dir, split, mode="sft", fill_min_length: Optio
if dataset == "minimath":
self.instruction_column = "question"
self.response_column = "answer"
elif dataset in ("wizardlm_70k", "evol_instruct_code"):
elif dataset in ("wizardlm_70k", "evol_instruct_code", "evol-codealpaca-v1"):
self.instruction_column = "instruction"
self.response_column = "output"
elif dataset == "cot_submix_original":
self.instruction_column = "inputs"
self.response_column = "targets"
elif dataset == "megacode":
self.instruction_column = "prompt"
self.response_column = "completion"
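
For reference, the two new registry entries point at Hugging Face Hub datasets whose column names differ; a rough sketch of reading them directly with the datasets library (the Hub ids and column names come from the mapping and branches above, the split name is assumed):

    from datasets import load_dataset

    ds = load_dataset("theblackcat102/evol-codealpaca-v1", split="train")
    prompt, response = ds[0]["instruction"], ds[0]["output"]

    # "cot_submix_original" (conceptofmind/cot_submix_original) instead exposes its
    # pairs as "inputs"/"targets", per the new branch above.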
