Skip to content

Commit

Permalink
Patching Anthropic System (#1130)
Browse files Browse the repository at this point in the history
Co-authored-by: Richie Caputo <[email protected]>
Co-authored-by: Richie Caputo <[email protected]>
  • Loading branch information
3 people authored Oct 29, 2024
1 parent 651465b commit da537a6
Show file tree
Hide file tree
Showing 10 changed files with 443 additions and 62 deletions.
8 changes: 1 addition & 7 deletions .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,7 @@ jobs:
python3 -m pip install -r requirements.txt
python3 -m pip install -r requirements-doc.txt
- name: Run Continuous Integration Action
run: |
set -e -o pipefail
export CUSTOM_PACKAGES="${{ env.CUSTOM_PACKAGES }}" &&
export CUSTOM_FLAGS="${{ env.CUSTOM_FLAGS }}" &&
curl -sSL https://raw.githubusercontent.com/gao-hongnan/omniverse/2fd5de1b8103e955cd5f022ab016b72fa901fa8f/scripts/devops/continuous-integration/lint_ruff.sh -o lint_ruff.sh
chmod +x lint_ruff.sh
bash lint_ruff.sh | tee ${{ env.WORKING_DIRECTORY }}/${{ env.RUFF_OUTPUT_FILENAME }}
uses: astral-sh/ruff-action@v1
- name: Upload Artifacts
uses: actions/upload-artifact@v3
with:
Expand Down
4 changes: 2 additions & 2 deletions docs/hooks/hide_lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

@mkdocs.plugins.event_priority(0)
# pylint: disable=unused-argument
def on_startup(command: str, dirty: bool) -> None:
def on_startup(command: str, dirty: bool) -> None: # noqa: ARG001
"""Monkey patch Highlight extension to hide lines in code blocks."""
original = highlight.Highlight.highlight
original = highlight.Highlight.highlight # type: ignore

def patched(self: Any, src: str, *args: Any, **kwargs: Any) -> Any:
lines = src.splitlines(keepends=True)
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/7-synthetic-data-generation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@
"\n",
"for metric,size in product(METRICS,SIZES):\n",
" metric_name, score_fn = metric\n",
" score_fns[f\"{metric_name}@{size}\"] = lambda predictions,labels : score_fn(predictions[:size],labels)"
" score_fns[f\"{metric_name}@{size}\"] = lambda predictions,labels, fn=score_fn, k=size: fn(predictions[:k],labels) # type: ignore"
]
},
{
Expand Down
7 changes: 1 addition & 6 deletions instructor/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,12 +507,7 @@ def from_openai(
instructor.Mode.MD_JSON,
}

if provider in {Provider.DATABRICKS}:
assert mode in {
instructor.Mode.MD_JSON
}, "Databricks provider only supports `MD_JSON` mode."

if provider in {Provider.OPENAI}:
if provider in {Provider.OPENAI, Provider.DATABRICKS}:
assert mode in {
instructor.Mode.TOOLS,
instructor.Mode.JSON,
Expand Down
55 changes: 23 additions & 32 deletions instructor/process_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
from instructor.dsl.partial import PartialBase
from instructor.dsl.simple_type import AdapterBase, ModelAdapter, is_simple_type
from instructor.function_calls import OpenAISchema, openai_schema
from instructor.utils import merge_consecutive_messages
from instructor.utils import (
merge_consecutive_messages,
extract_system_messages,
combine_system_messages,
)
from instructor.multimodal import convert_messages

logger = logging.getLogger("instructor")
Expand Down Expand Up @@ -332,20 +336,15 @@ def handle_anthropic_tools(
"name": response_model.__name__,
}

system_messages = [
m["content"] for m in new_kwargs["messages"] if m["role"] == "system"
]
system_messages = extract_system_messages(new_kwargs.get("messages", []))

if "system" in new_kwargs and system_messages:
raise ValueError(
"Only a single system message is supported - either set it as a message in the messages array or use the system parameter"
if system_messages:
new_kwargs["system"] = combine_system_messages(
new_kwargs.get("system"), system_messages
)

if "system" not in new_kwargs:
new_kwargs["system"] = "\n\n".join(system_messages)

new_kwargs["messages"] = [
m for m in new_kwargs["messages"] if m["role"] != "system"
m for m in new_kwargs.get("messages", []) if m["role"] != "system"
]

return response_model, new_kwargs
Expand All @@ -354,25 +353,18 @@ def handle_anthropic_tools(
def handle_anthropic_json(
response_model: type[T], new_kwargs: dict[str, Any]
) -> tuple[type[T], dict[str, Any]]:
openai_system_messages = [
message["content"]
for message in new_kwargs.get("messages", [])
if message["role"] == "system"
]
system_messages = extract_system_messages(new_kwargs.get("messages", []))

if "system" in new_kwargs and openai_system_messages:
raise ValueError(
"Only a single System message is supported - either set it using the system parameter or in the list of messages"
if system_messages:
new_kwargs["system"] = combine_system_messages(
new_kwargs.get("system"), system_messages
)

if not "system" in new_kwargs:
new_kwargs["system"] = "\n\n".join(openai_system_messages)

new_kwargs["messages"] = [
m for m in new_kwargs["messages"] if m["role"] != "system"
m for m in new_kwargs.get("messages", []) if m["role"] != "system"
]

message = dedent(
json_schema_message = dedent(
f"""
As a genius expert, your task is to understand the content and provide
the parsed objects in json that match the following json_schema:\n
Expand All @@ -383,7 +375,9 @@ def handle_anthropic_json(
"""
)

new_kwargs["system"] = f"{new_kwargs.get('system', '')}\n\n{message}".strip()
new_kwargs["system"] = combine_system_messages(
new_kwargs.get("system"), [{"type": "text", "text": json_schema_message}]
)

return response_model, new_kwargs

Expand Down Expand Up @@ -664,22 +658,19 @@ def handle_response_model(
# This is because Cohere uses 'message' and 'chat_history' instead of 'messages'
return handle_cohere_modes(new_kwargs)
# Handle images without a response model
if autodetect_images and "messages" in new_kwargs:
if "messages" in new_kwargs:
messages = convert_messages(
new_kwargs["messages"],
mode,
autodetect_images=autodetect_images,
)
if mode in {Mode.ANTHROPIC_JSON, Mode.ANTHROPIC_TOOLS}:
# Handle OpenAI style or Anthropic style messages
new_kwargs["messages"] = [
m for m in messages if m["role"] != "system"
]
new_kwargs["messages"] = [m for m in messages if m["role"] != "system"]
if "system" not in new_kwargs:
system_messages = (m for m in messages if m["role"] == "system")
system_message = next(system_messages, None)
system_message = extract_system_messages(messages)
if system_message:
new_kwargs["system"] = system_message["content"]
new_kwargs["system"] = system_message
else:
new_kwargs["messages"] = messages
return None, new_kwargs
Expand Down
53 changes: 52 additions & 1 deletion instructor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
Callable,
Generic,
Protocol,
Union,
TypedDict,
TypeVar,
cast
)
from pydantic import BaseModel
import os
Expand Down Expand Up @@ -131,7 +134,6 @@ def update_total_usage(
response: T_Model | None,
total_usage: OpenAIUsage | AnthropicUsage,
) -> T_Model | ChatCompletion | None:

if response is None:
return None

Expand Down Expand Up @@ -369,3 +371,52 @@ def update_gemini_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:

def disable_pydantic_error_url():
os.environ["PYDANTIC_ERRORS_INCLUDE_URL"] = "0"


class SystemMessage(TypedDict, total=False):
    """Dictionary shape of one system-prompt block.

    ``total=False`` makes every key optional, so a block may carry any
    subset of the fields below.
    """

    # Block kind; the helpers in this module always build blocks with "text".
    type: str
    # The prompt text carried by the block.
    text: str
    # NOTE(review): presumably a provider cache directive passed through
    # unchanged -- confirm the expected keys against the provider docs.
    cache_control: dict[str, str]


def combine_system_messages(
    existing_system: Union[str, list[SystemMessage], None],  # noqa: UP007
    new_system: Union[str, list[SystemMessage]],  # noqa: UP007
) -> Union[str, list[SystemMessage]]:  # noqa: UP007
    """Append ``new_system`` to ``existing_system`` and return the result.

    Args:
        existing_system: The system prompt already present, either a plain
            string, a list of system blocks, or ``None`` when there is none.
        new_system: The prompt to add, as a plain string or a list of blocks.

    Returns:
        ``new_system`` unchanged when there is no existing content (``None``
        or a blank string). Two strings are joined with a blank line between
        them. Any combination involving a list yields a new list, with a
        string operand wrapped as a ``{"type": "text", "text": ...}`` block.

    Raises:
        ValueError: If the operands are neither strings nor lists.
    """
    # A blank existing prompt carries no content. Merging it would produce a
    # string starting with "\n\n" or a leading empty text block, which the
    # Anthropic API is expected to reject, so treat it like a missing prompt.
    if existing_system is None or (
        isinstance(existing_system, str) and not existing_system.strip()
    ):
        return new_system

    if isinstance(existing_system, str) and isinstance(new_system, str):
        return f"{existing_system}\n\n{new_system}"

    if isinstance(existing_system, list) and isinstance(new_system, list):
        # Concatenation builds a fresh list; neither input is mutated.
        return existing_system + new_system

    if isinstance(existing_system, str) and isinstance(new_system, list):
        return [SystemMessage(type="text", text=existing_system)] + new_system

    if isinstance(existing_system, list) and isinstance(new_system, str):
        return existing_system + [SystemMessage(type="text", text=new_system)]

    raise ValueError("Unsupported system message type combination")


def extract_system_messages(messages: list[dict[str, Any]]) -> list[SystemMessage]:
    """Collect the content of every ``"system"``-role message as blocks.

    Args:
        messages: Chat messages; each must have a ``"role"`` key, and system
            messages must also have a ``"content"`` key.

    Returns:
        One ``SystemMessage`` per piece of system content, in message order.
        String content becomes a ``{"type": "text", "text": ...}`` block;
        dict content is passed through as the block's keys. List content
        contributes one block per item.

    Raises:
        ValueError: If a piece of content is neither a string nor a dict.
    """
    extracted: list[SystemMessage] = []
    for message in messages:
        if message["role"] != "system":
            continue

        # System content must always be a string or a list of dictionaries.
        payload = cast(Union[str, list[dict[str, Any]]], message["content"])  # noqa: UP007

        # Treat scalar content as a one-item list so both shapes share a path.
        parts = payload if isinstance(payload, list) else [payload]
        for part in parts:
            if isinstance(part, str):
                extracted.append(SystemMessage(type="text", text=part))
            elif isinstance(part, dict):
                extracted.append(SystemMessage(**part))
            else:
                raise ValueError(f"Unsupported content type: {type(part)}")
    return extracted
14 changes: 7 additions & 7 deletions make_desc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Optional, List, Set, Literal
from typing import Optional, Literal
import asyncio
from openai import AsyncOpenAI
import typer
Expand All @@ -15,7 +15,7 @@


async def generate_ai_frontmatter(
client: AsyncOpenAI, title: str, content: str, categories: List[str]
client: AsyncOpenAI, title: str, content: str, categories: list[str]
):
"""
Generate a description and categories for the given content using AI.
Expand All @@ -35,8 +35,8 @@ class DescriptionAndCategories(BaseModel):
reasoning: str = Field(
..., description="The reasoning for the correct categories"
)
tags: List[str]
categories: List[
tags: list[str]
categories: list[
Literal[
"OpenAI",
"Anthropic",
Expand Down Expand Up @@ -72,7 +72,7 @@ class DescriptionAndCategories(BaseModel):
return response


def get_all_categories(root_dir: str) -> Set[str]:
def get_all_categories(root_dir: str) -> set[str]:
"""
Read all markdown files and extract unique categories.
Expand Down Expand Up @@ -113,7 +113,7 @@ def preview_categories(root_dir: str) -> None:


async def process_file(
client: AsyncOpenAI, file_path: str, categories: List[str], enable_comments: bool
client: AsyncOpenAI, file_path: str, categories: list[str], enable_comments: bool
) -> None:
"""
Process a single file, adding or updating the description and categories in the front matter.
Expand Down Expand Up @@ -143,7 +143,7 @@ async def process_file(

async def process_files(
root_dir: str,
api_key: Optional[str] = None,
api_key: Optional[str] = None, # noqa: ARG001
use_categories: bool = False,
enable_comments: bool = False,
) -> None:
Expand Down
17 changes: 11 additions & 6 deletions tests/llm/test_anthropic/evals/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,19 @@ class User(BaseModel):

@field_validator("name")
def name_is_uppercase(cls, v: str):
assert v.isupper(), "Name must be uppercase, please fix"
assert v.isupper(), f"{v} is not an uppercased string. Note that all characters in {v} must be uppercase (EG. TIM SARAH ADAM)."
return v

resp = client.messages.create(
model="claude-3-haiku-20240307",
max_tokens=1024,
max_tokens=4096,
max_retries=2,
system="Make sure to follow the instructions carefully and return a response object that matches the json schema requested. Age is an integer.",
messages=[
{
"role": "user",
"content": "Extract John is 18 years old.",
}
},
],
response_model=User,
) # type: ignore
Expand All @@ -53,7 +54,7 @@ class User(BaseModel):

resp = client.messages.create(
model="claude-3-haiku-20240307",
max_tokens=1024,
max_tokens=4096,
max_retries=0,
messages=[
{
Expand Down Expand Up @@ -83,6 +84,7 @@ class User(BaseModel):
model="claude-3-haiku-20240307",
max_tokens=1024,
max_retries=0,
system="Make sure to follow the instructions carefully and return a response object that matches the json schema requested. Family members here is just asking for a list of names",
messages=[
{
"role": "user",
Expand Down Expand Up @@ -132,7 +134,7 @@ class User(BaseModel):

resp = client.messages.create(
model="claude-3-haiku-20240307",
max_tokens=1024,
max_tokens=4096,
max_retries=2,
messages=[
{
Expand Down Expand Up @@ -185,7 +187,10 @@ class User(BaseModel):
max_tokens=1024,
max_retries=0,
messages=[
{"role": "system", "content": "EVERYTHING MUST BE IN ALL CAPS"},
{
"role": "system",
"content": "Please make sure to follow the instructions carefully and return a valid response object. All strings must be fully capitalised in all caps. (Eg. THIS IS AN UPPERCASE STRING) and age is an integer.",
},
{
"role": "user",
"content": "Create a user for a model with a name and age.",
Expand Down
Loading

0 comments on commit da537a6

Please sign in to comment.