From 877e457f0e44ff76ed8811b5f6885a42bff81ef4 Mon Sep 17 00:00:00 2001
From: Xander Song
Date: Wed, 30 Oct 2024 08:48:53 -0700
Subject: [PATCH] refactor(anthropic): create helper functions for wrangling attributes (#1086)

---
 .../instrumentation/anthropic/_wrappers.py | 208 +++++++++---------
 1 file changed, 110 insertions(+), 98 deletions(-)

diff --git a/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_wrappers.py b/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_wrappers.py
index fefb2d564..c5d0e8204 100644
--- a/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_wrappers.py
+++ b/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_wrappers.py
@@ -1,6 +1,7 @@
 from abc import ABC
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Iterator, List, Mapping, Tuple
+from itertools import chain
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple
 
 import opentelemetry.context as context_api
 from opentelemetry import trace as trace_api
@@ -24,6 +25,11 @@
     ToolCallAttributes,
 )
 
+if TYPE_CHECKING:
+    from pydantic import BaseModel
+
+    from anthropic.types import Usage
+
 
 class _WithTracer(ABC):
     """
@@ -36,15 +42,14 @@ def __init__(self, tracer: trace_api.Tracer, *args: Any, **kwargs: Any) -> None:
 
     @contextmanager
     def _start_as_current_span(
-        self,
-        span_name: str,
+        self, span_name: str, attributes: Optional[Mapping[str, Any]] = None
     ) -> Iterator[_WithSpan]:
-        # Because OTEL has a default limit of 128 attributes, we split our attributes into
-        # two tiers, where the addition of "extra_attributes" is deferred until the end
-        # and only after the "attributes" are added.
         try:
             span = self._tracer.start_span(
-                name=span_name, record_exception=False, set_status_on_exception=False
+                name=span_name,
+                record_exception=False,
+                set_status_on_exception=False,
+                attributes=attributes,
             )
         except Exception:
             span = INVALID_SPAN
@@ -81,24 +86,21 @@ def __call__(
         llm_prompt = dict(arguments).pop("prompt", None)
         llm_invocation_parameters = _get_invocation_parameters(arguments)
 
-        span_name = "Completions"
         with self._start_as_current_span(
-            span_name,
+            span_name="Completions",
+            attributes=dict(
+                chain(
+                    get_attributes_from_context(),
+                    _get_llm_model(arguments),
+                    _get_llm_provider(),
+                    _get_llm_system(),
+                    _get_llm_span_kind(),
+                    _get_llm_prompts(llm_prompt),
+                    _get_inputs(arguments),
+                    _get_llm_invocation_parameters(llm_invocation_parameters),
+                )
+            ),
         ) as span:
-            span.set_attributes(dict(get_attributes_from_context()))
-
-            span.set_attributes(
-                {
-                    **dict(_get_llm_model(arguments)),
-                    **dict(_get_llm_provider()),
-                    **dict(_get_llm_system()),
-                    OPENINFERENCE_SPAN_KIND: LLM,
-                    LLM_PROMPTS: [llm_prompt],
-                    INPUT_VALUE: safe_json_dumps(arguments),
-                    INPUT_MIME_TYPE: JSON,
-                    LLM_INVOCATION_PARAMETERS: safe_json_dumps(llm_invocation_parameters),
-                }
-            )
             try:
                 response = wrapped(*args, **kwargs)
             except Exception as exception:
@@ -111,12 +113,7 @@ def __call__(
                 return _Stream(response, span)
             else:
                 span.set_status(trace_api.StatusCode.OK)
-                span.set_attributes(
-                    {
-                        OUTPUT_VALUE: response.model_dump_json(),
-                        OUTPUT_MIME_TYPE: JSON,
-                    }
-                )
+                span.set_attributes(dict(_get_outputs(response)))
                 span.finish_tracing()
                 return response
 
@@ -141,24 +138,21 @@ async def __call__(
         llm_prompt = dict(arguments).pop("prompt", None)
         invocation_parameters = _get_invocation_parameters(arguments)
 
-        span_name = "AsyncCompletions"
         with self._start_as_current_span(
-            span_name,
+            span_name="AsyncCompletions",
+            attributes=dict(
+                chain(
+                    get_attributes_from_context(),
+                    _get_llm_model(arguments),
+                    _get_llm_provider(),
+                    _get_llm_system(),
+                    _get_llm_span_kind(),
+                    _get_llm_prompts(llm_prompt),
+                    _get_inputs(arguments),
+                    _get_llm_invocation_parameters(invocation_parameters),
+                )
+            ),
         ) as span:
-            span.set_attributes(dict(get_attributes_from_context()))
-
-            span.set_attributes(
-                {
-                    **dict(_get_llm_model(arguments)),
-                    **dict(_get_llm_provider()),
-                    **dict(_get_llm_system()),
-                    OPENINFERENCE_SPAN_KIND: LLM,
-                    LLM_PROMPTS: [llm_prompt],
-                    INPUT_VALUE: safe_json_dumps(arguments),
-                    INPUT_MIME_TYPE: JSON,
-                    LLM_INVOCATION_PARAMETERS: safe_json_dumps(invocation_parameters),
-                }
-            )
             try:
                 response = await wrapped(*args, **kwargs)
             except Exception as exception:
@@ -171,12 +165,7 @@ async def __call__(
                 return _Stream(response, span)
             else:
                 span.set_status(trace_api.StatusCode.OK)
-                span.set_attributes(
-                    {
-                        OUTPUT_VALUE: response.to_json(indent=None),
-                        OUTPUT_MIME_TYPE: JSON,
-                    }
-                )
+                span.set_attributes(dict(_get_outputs(response)))
                 span.finish_tracing()
                 return response
 
@@ -201,24 +190,21 @@ def __call__(
         llm_input_messages = dict(arguments).pop("messages", None)
         invocation_parameters = _get_invocation_parameters(arguments)
 
-        span_name = "Messages"
         with self._start_as_current_span(
-            span_name,
+            span_name="Messages",
+            attributes=dict(
+                chain(
+                    get_attributes_from_context(),
+                    _get_llm_model(arguments),
+                    _get_llm_provider(),
+                    _get_llm_system(),
+                    _get_llm_span_kind(),
+                    _get_llm_input_messages(llm_input_messages),
+                    _get_llm_invocation_parameters(invocation_parameters),
+                    _get_inputs(arguments),
+                )
+            ),
         ) as span:
-            span.set_attributes(dict(get_attributes_from_context()))
-
-            span.set_attributes(
-                {
-                    **dict(_get_llm_model(arguments)),
-                    **dict(_get_llm_provider()),
-                    **dict(_get_llm_system()),
-                    OPENINFERENCE_SPAN_KIND: LLM,
-                    **dict(_get_input_messages(llm_input_messages)),
-                    LLM_INVOCATION_PARAMETERS: safe_json_dumps(invocation_parameters),
-                    INPUT_VALUE: safe_json_dumps(arguments),
-                    INPUT_MIME_TYPE: JSON,
-                }
-            )
             try:
                 response = wrapped(*args, **kwargs)
             except Exception as exception:
@@ -231,13 +217,13 @@ def __call__(
             else:
                 span.set_status(trace_api.StatusCode.OK)
                 span.set_attributes(
-                    {
-                        **dict(_get_output_messages(response)),
-                        LLM_TOKEN_COUNT_PROMPT: response.usage.input_tokens,
-                        LLM_TOKEN_COUNT_COMPLETION: response.usage.output_tokens,
-                        OUTPUT_VALUE: response.model_dump_json(),
-                        OUTPUT_MIME_TYPE: JSON,
-                    }
+                    dict(
+                        chain(
+                            _get_output_messages(response),
+                            _get_llm_token_counts(response.usage),
+                            _get_outputs(response),
+                        )
+                    )
                 )
                 span.finish_tracing()
                 return response
@@ -263,24 +249,21 @@ async def __call__(
         llm_input_messages = dict(arguments).pop("messages", None)
         invocation_parameters = _get_invocation_parameters(arguments)
 
-        span_name = "AsyncMessages"
         with self._start_as_current_span(
-            span_name,
+            span_name="AsyncMessages",
+            attributes=dict(
+                chain(
+                    get_attributes_from_context(),
+                    _get_llm_provider(),
+                    _get_llm_system(),
+                    _get_llm_model(arguments),
+                    _get_llm_span_kind(),
+                    _get_llm_input_messages(llm_input_messages),
+                    _get_llm_invocation_parameters(invocation_parameters),
+                    _get_inputs(arguments),
+                )
+            ),
         ) as span:
-            span.set_attributes(dict(get_attributes_from_context()))
-
-            span.set_attributes(
-                {
-                    **dict(_get_llm_provider()),
-                    **dict(_get_llm_system()),
-                    **dict(_get_llm_model(arguments)),
-                    OPENINFERENCE_SPAN_KIND: LLM,
-                    **dict(_get_input_messages(llm_input_messages)),
-                    LLM_INVOCATION_PARAMETERS: safe_json_dumps(invocation_parameters),
-                    INPUT_VALUE: safe_json_dumps(arguments),
-                    INPUT_MIME_TYPE: JSON,
-                }
-            )
             try:
                 response = await wrapped(*args, **kwargs)
             except Exception as exception:
@@ -293,18 +276,32 @@ async def __call__(
             else:
                 span.set_status(trace_api.StatusCode.OK)
                 span.set_attributes(
-                    {
-                        **dict(_get_output_messages(response)),
-                        LLM_TOKEN_COUNT_PROMPT: response.usage.input_tokens,
-                        LLM_TOKEN_COUNT_COMPLETION: response.usage.output_tokens,
-                        OUTPUT_VALUE: response.model_dump_json(),
-                        OUTPUT_MIME_TYPE: JSON,
-                    }
+                    dict(
+                        chain(
+                            _get_output_messages(response),
+                            _get_llm_token_counts(response.usage),
+                            _get_outputs(response),
+                        )
+                    )
                 )
                 span.finish_tracing()
                 return response
 
 
+def _get_inputs(arguments: Mapping[str, Any]) -> Iterator[Tuple[str, Any]]:
+    yield INPUT_VALUE, safe_json_dumps(arguments)
+    yield INPUT_MIME_TYPE, JSON
+
+
+def _get_outputs(response: "BaseModel") -> Iterator[Tuple[str, Any]]:
+    yield OUTPUT_VALUE, response.model_dump_json()
+    yield OUTPUT_MIME_TYPE, JSON
+
+
+def _get_llm_span_kind() -> Iterator[Tuple[str, Any]]:
+    yield OPENINFERENCE_SPAN_KIND, LLM
+
+
 def _get_llm_provider() -> Iterator[Tuple[str, Any]]:
     yield LLM_PROVIDER, LLM_PROVIDER_ANTHROPIC
 
@@ -313,12 +310,27 @@ def _get_llm_system() -> Iterator[Tuple[str, Any]]:
     yield LLM_SYSTEM, LLM_SYSTEM_ANTHROPIC
 
 
+def _get_llm_token_counts(usage: "Usage") -> Iterator[Tuple[str, Any]]:
+    yield LLM_TOKEN_COUNT_PROMPT, usage.input_tokens
+    yield LLM_TOKEN_COUNT_COMPLETION, usage.output_tokens
+
+
 def _get_llm_model(arguments: Mapping[str, Any]) -> Iterator[Tuple[str, Any]]:
     if model_name := arguments.get("model"):
         yield LLM_MODEL_NAME, model_name
 
 
-def _get_input_messages(messages: List[Dict[str, str]]) -> Any:
+def _get_llm_invocation_parameters(
+    invocation_parameters: Mapping[str, Any],
+) -> Iterator[Tuple[str, Any]]:
+    yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)
+
+
+def _get_llm_prompts(prompt: str) -> Iterator[Tuple[str, Any]]:
+    yield LLM_PROMPTS, [prompt]
+
+
+def _get_llm_input_messages(messages: List[Dict[str, str]]) -> Any:
     """
     Extracts the messages from the chat response
     """