refactor(anthropic): create helper functions for wrangling attributes (
axiomofjoy authored Oct 30, 2024
1 parent e8b4b51 commit 877e457
Showing 1 changed file with 110 additions and 98 deletions.
@@ -1,6 +1,7 @@
from abc import ABC
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, List, Mapping, Tuple
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple

import opentelemetry.context as context_api
from opentelemetry import trace as trace_api
@@ -24,6 +25,11 @@
ToolCallAttributes,
)

if TYPE_CHECKING:
from pydantic import BaseModel

from anthropic.types import Usage


class _WithTracer(ABC):
"""
@@ -36,15 +42,14 @@ def __init__(self, tracer: trace_api.Tracer, *args: Any, **kwargs: Any) -> None:

@contextmanager
def _start_as_current_span(
self,
span_name: str,
self, span_name: str, attributes: Optional[Mapping[str, Any]] = None
) -> Iterator[_WithSpan]:
# Because OTEL has a default limit of 128 attributes, we split our attributes into
# two tiers, where the addition of "extra_attributes" is deferred until the end
# and only after the "attributes" are added.
try:
span = self._tracer.start_span(
name=span_name, record_exception=False, set_status_on_exception=False
name=span_name,
record_exception=False,
set_status_on_exception=False,
attributes=attributes,
)
except Exception:
span = INVALID_SPAN
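
For context on the new attributes parameter: OpenTelemetry's Tracer.start_span accepts an attributes mapping directly, so the merged attribute dict can be attached when the span is created instead of via set_attributes afterwards. A standalone sketch of that idea against the plain OTel SDK, not this package's _WithSpan wrapper; the exporter setup and attribute values are illustrative.

from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

# Minimal tracer setup so the example runs on its own.
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
trace_api.set_tracer_provider(provider)
tracer = trace_api.get_tracer(__name__)

# Illustrative attribute values; the real keys come from the semantic conventions module.
attributes = {"openinference.span.kind": "LLM", "llm.model_name": "claude-2.1"}
span = tracer.start_span(
    name="Completions",
    record_exception=False,
    set_status_on_exception=False,
    attributes=attributes,  # attached at creation, mirroring the refactored call above
)
span.set_status(trace_api.StatusCode.OK)
span.end()
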
@@ -81,24 +86,21 @@ def __call__(
llm_prompt = dict(arguments).pop("prompt", None)
llm_invocation_parameters = _get_invocation_parameters(arguments)

span_name = "Completions"
with self._start_as_current_span(
span_name,
span_name="Completions",
attributes=dict(
chain(
get_attributes_from_context(),
_get_llm_model(arguments),
_get_llm_provider(),
_get_llm_system(),
_get_llm_span_kind(),
_get_llm_prompts(llm_prompt),
_get_inputs(arguments),
_get_llm_invocation_parameters(llm_invocation_parameters),
)
),
) as span:
span.set_attributes(dict(get_attributes_from_context()))

span.set_attributes(
{
**dict(_get_llm_model(arguments)),
**dict(_get_llm_provider()),
**dict(_get_llm_system()),
OPENINFERENCE_SPAN_KIND: LLM,
LLM_PROMPTS: [llm_prompt],
INPUT_VALUE: safe_json_dumps(arguments),
INPUT_MIME_TYPE: JSON,
LLM_INVOCATION_PARAMETERS: safe_json_dumps(llm_invocation_parameters),
}
)
try:
response = wrapped(*args, **kwargs)
except Exception as exception:
@@ -111,12 +113,7 @@ def __call__(
return _Stream(response, span)
else:
span.set_status(trace_api.StatusCode.OK)
span.set_attributes(
{
OUTPUT_VALUE: response.model_dump_json(),
OUTPUT_MIME_TYPE: JSON,
}
)
span.set_attributes(dict(_get_outputs(response)))
span.finish_tracing()
return response
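
A minimal sketch of the pattern this hunk adopts: each helper is a generator yielding (attribute_key, value) pairs, and itertools.chain flattens them into the single mapping handed to the span at creation time. The helper bodies mirror the diff, but the literal attribute keys, the json.dumps stand-in for safe_json_dumps, and the example arguments are illustrative.

import json
from itertools import chain
from typing import Any, Iterator, Mapping, Tuple


def _get_llm_model(arguments: Mapping[str, Any]) -> Iterator[Tuple[str, Any]]:
    # Yield the model name only when the caller supplied one.
    if model_name := arguments.get("model"):
        yield "llm.model_name", model_name  # placeholder for LLM_MODEL_NAME


def _get_inputs(arguments: Mapping[str, Any]) -> Iterator[Tuple[str, Any]]:
    # Serialize the full request as the span's input payload.
    yield "input.value", json.dumps(arguments)  # placeholder for INPUT_VALUE / safe_json_dumps
    yield "input.mime_type", "application/json"  # placeholder for INPUT_MIME_TYPE / JSON


arguments = {"model": "claude-2.1", "prompt": "Hello", "max_tokens_to_sample": 128}
attributes = dict(chain(_get_llm_model(arguments), _get_inputs(arguments)))
# {'llm.model_name': 'claude-2.1', 'input.value': '{...}', 'input.mime_type': 'application/json'}
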

@@ -141,24 +138,21 @@ async def __call__(
llm_prompt = dict(arguments).pop("prompt", None)
invocation_parameters = _get_invocation_parameters(arguments)

span_name = "AsyncCompletions"
with self._start_as_current_span(
span_name,
span_name="AsyncCompletions",
attributes=dict(
chain(
get_attributes_from_context(),
_get_llm_model(arguments),
_get_llm_provider(),
_get_llm_system(),
_get_llm_span_kind(),
_get_llm_prompts(llm_prompt),
_get_inputs(arguments),
_get_llm_invocation_parameters(invocation_parameters),
)
),
) as span:
span.set_attributes(dict(get_attributes_from_context()))

span.set_attributes(
{
**dict(_get_llm_model(arguments)),
**dict(_get_llm_provider()),
**dict(_get_llm_system()),
OPENINFERENCE_SPAN_KIND: LLM,
LLM_PROMPTS: [llm_prompt],
INPUT_VALUE: safe_json_dumps(arguments),
INPUT_MIME_TYPE: JSON,
LLM_INVOCATION_PARAMETERS: safe_json_dumps(invocation_parameters),
}
)
try:
response = await wrapped(*args, **kwargs)
except Exception as exception:
@@ -171,12 +165,7 @@ async def __call__(
return _Stream(response, span)
else:
span.set_status(trace_api.StatusCode.OK)
span.set_attributes(
{
OUTPUT_VALUE: response.to_json(indent=None),
OUTPUT_MIME_TYPE: JSON,
}
)
span.set_attributes(dict(_get_outputs(response)))
span.finish_tracing()
return response

@@ -201,24 +190,21 @@ def __call__(
llm_input_messages = dict(arguments).pop("messages", None)
invocation_parameters = _get_invocation_parameters(arguments)

span_name = "Messages"
with self._start_as_current_span(
span_name,
span_name="Messages",
attributes=dict(
chain(
get_attributes_from_context(),
_get_llm_model(arguments),
_get_llm_provider(),
_get_llm_system(),
_get_llm_span_kind(),
_get_llm_input_messages(llm_input_messages),
_get_llm_invocation_parameters(invocation_parameters),
_get_inputs(arguments),
)
),
) as span:
span.set_attributes(dict(get_attributes_from_context()))

span.set_attributes(
{
**dict(_get_llm_model(arguments)),
**dict(_get_llm_provider()),
**dict(_get_llm_system()),
OPENINFERENCE_SPAN_KIND: LLM,
**dict(_get_input_messages(llm_input_messages)),
LLM_INVOCATION_PARAMETERS: safe_json_dumps(invocation_parameters),
INPUT_VALUE: safe_json_dumps(arguments),
INPUT_MIME_TYPE: JSON,
}
)
try:
response = wrapped(*args, **kwargs)
except Exception as exception:
@@ -231,13 +217,13 @@ def __call__(
else:
span.set_status(trace_api.StatusCode.OK)
span.set_attributes(
{
**dict(_get_output_messages(response)),
LLM_TOKEN_COUNT_PROMPT: response.usage.input_tokens,
LLM_TOKEN_COUNT_COMPLETION: response.usage.output_tokens,
OUTPUT_VALUE: response.model_dump_json(),
OUTPUT_MIME_TYPE: JSON,
}
dict(
chain(
_get_output_messages(response),
_get_llm_token_counts(response.usage),
_get_outputs(response),
)
)
)
span.finish_tracing()
return response
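
The response side of the Messages wrappers is assembled the same way: _get_llm_token_counts reads Anthropic's usage block and yields the prompt/completion token-count attributes, which chain merges with the output-message and output-value pairs. A self-contained sketch; the Usage dataclass stands in for anthropic.types.Usage and the literal keys stand in for the semantic-convention constants.

from dataclasses import dataclass
from typing import Any, Iterator, Tuple


@dataclass
class Usage:
    # Stand-in for anthropic.types.Usage, which exposes these two counters.
    input_tokens: int
    output_tokens: int


def _get_llm_token_counts(usage: Usage) -> Iterator[Tuple[str, Any]]:
    yield "llm.token_count.prompt", usage.input_tokens  # placeholder for LLM_TOKEN_COUNT_PROMPT
    yield "llm.token_count.completion", usage.output_tokens  # placeholder for LLM_TOKEN_COUNT_COMPLETION


print(dict(_get_llm_token_counts(Usage(input_tokens=12, output_tokens=34))))
# {'llm.token_count.prompt': 12, 'llm.token_count.completion': 34}
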
@@ -263,24 +249,21 @@ async def __call__(
llm_input_messages = dict(arguments).pop("messages", None)
invocation_parameters = _get_invocation_parameters(arguments)

span_name = "AsyncMessages"
with self._start_as_current_span(
span_name,
span_name="AsyncMessages",
attributes=dict(
chain(
get_attributes_from_context(),
_get_llm_provider(),
_get_llm_system(),
_get_llm_model(arguments),
_get_llm_span_kind(),
_get_llm_input_messages(llm_input_messages),
_get_llm_invocation_parameters(invocation_parameters),
_get_inputs(arguments),
)
),
) as span:
span.set_attributes(dict(get_attributes_from_context()))

span.set_attributes(
{
**dict(_get_llm_provider()),
**dict(_get_llm_system()),
**dict(_get_llm_model(arguments)),
OPENINFERENCE_SPAN_KIND: LLM,
**dict(_get_input_messages(llm_input_messages)),
LLM_INVOCATION_PARAMETERS: safe_json_dumps(invocation_parameters),
INPUT_VALUE: safe_json_dumps(arguments),
INPUT_MIME_TYPE: JSON,
}
)
try:
response = await wrapped(*args, **kwargs)
except Exception as exception:
@@ -293,18 +276,32 @@ async def __call__(
else:
span.set_status(trace_api.StatusCode.OK)
span.set_attributes(
{
**dict(_get_output_messages(response)),
LLM_TOKEN_COUNT_PROMPT: response.usage.input_tokens,
LLM_TOKEN_COUNT_COMPLETION: response.usage.output_tokens,
OUTPUT_VALUE: response.model_dump_json(),
OUTPUT_MIME_TYPE: JSON,
}
dict(
chain(
_get_output_messages(response),
_get_llm_token_counts(response.usage),
_get_outputs(response),
)
)
)
span.finish_tracing()
return response


def _get_inputs(arguments: Mapping[str, Any]) -> Iterator[Tuple[str, Any]]:
yield INPUT_VALUE, safe_json_dumps(arguments)
yield INPUT_MIME_TYPE, JSON


def _get_outputs(response: "BaseModel") -> Iterator[Tuple[str, Any]]:
yield OUTPUT_VALUE, response.model_dump_json()
yield OUTPUT_MIME_TYPE, JSON


def _get_llm_span_kind() -> Iterator[Tuple[str, Any]]:
yield OPENINFERENCE_SPAN_KIND, LLM


def _get_llm_provider() -> Iterator[Tuple[str, Any]]:
yield LLM_PROVIDER, LLM_PROVIDER_ANTHROPIC

@@ -313,12 +310,27 @@ def _get_llm_system() -> Iterator[Tuple[str, Any]]:
yield LLM_SYSTEM, LLM_SYSTEM_ANTHROPIC


def _get_llm_token_counts(usage: "Usage") -> Iterator[Tuple[str, Any]]:
yield LLM_TOKEN_COUNT_PROMPT, usage.input_tokens
yield LLM_TOKEN_COUNT_COMPLETION, usage.output_tokens


def _get_llm_model(arguments: Mapping[str, Any]) -> Iterator[Tuple[str, Any]]:
if model_name := arguments.get("model"):
yield LLM_MODEL_NAME, model_name


def _get_input_messages(messages: List[Dict[str, str]]) -> Any:
def _get_llm_invocation_parameters(
invocation_parameters: Mapping[str, Any],
) -> Iterator[Tuple[str, Any]]:
yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)


def _get_llm_prompts(prompt: str) -> Iterator[Tuple[str, Any]]:
yield LLM_PROMPTS, [prompt]


def _get_llm_input_messages(messages: List[Dict[str, str]]) -> Any:
"""
Extracts the messages from the chat response
"""
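
The tail of the diff renames _get_input_messages to _get_llm_input_messages; its body is collapsed above. As a general illustration of how a list of chat messages can be flattened into indexed span attributes, not necessarily this package's implementation, a hypothetical helper with illustrative key names:

from typing import Any, Dict, Iterator, List, Tuple


def flatten_input_messages(messages: List[Dict[str, str]]) -> Iterator[Tuple[str, Any]]:
    # Hypothetical helper: emit one indexed attribute per message field,
    # e.g. "llm.input_messages.0.message.role" -> "user".
    for index, message in enumerate(messages):
        if (role := message.get("role")) is not None:
            yield f"llm.input_messages.{index}.message.role", role
        if (content := message.get("content")) is not None:
            yield f"llm.input_messages.{index}.message.content", content


print(dict(flatten_input_messages([{"role": "user", "content": "Hello"}])))
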
