HOME


Mini Shell 1.0
DIR: /opt/hc_python/lib64/python3.12/site-packages/sentry_sdk/integrations/google_genai/
Upload File :
Current File : //opt/hc_python/lib64/python3.12/site-packages/sentry_sdk/integrations/google_genai/utils.py
import copy
import inspect
from functools import wraps
from .consts import ORIGIN, TOOL_ATTRIBUTES_MAP, GEN_AI_SYSTEM
from typing import (
    cast,
    TYPE_CHECKING,
    Iterable,
    Any,
    Callable,
    List,
    Optional,
    Union,
    TypedDict,
)

import sentry_sdk
from sentry_sdk.ai.utils import (
    set_data_normalized,
    truncate_and_annotate_messages,
    normalize_message_roles,
)
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.utils import (
    capture_internal_exceptions,
    event_from_exception,
    safe_serialize,
)
from google.genai.types import GenerateContentConfig

if TYPE_CHECKING:
    from sentry_sdk.tracing import Span
    from google.genai.types import (
        GenerateContentResponse,
        ContentListUnion,
        Tool,
        Model,
    )


class UsageData(TypedDict):
    """Structure for token usage data.

    Built by ``extract_usage_data``, which fills every field with 0 when the
    response carries no corresponding count.
    """

    # Prompt tokens plus tool-use prompt tokens.
    input_tokens: int
    # Tokens served from cached content.
    input_tokens_cached: int
    # Candidate (output) tokens plus reasoning ("thoughts") tokens.
    output_tokens: int
    # Reasoning ("thoughts") tokens; these are also counted in output_tokens.
    output_tokens_reasoning: int
    # Total token count exactly as reported by the response.
    total_tokens: int


def extract_usage_data(response):
    # type: (Union[GenerateContentResponse, dict[str, Any]]) -> UsageData
    """Extract usage data from response into a structured format.

    Both response shapes go through the same code path: a plain dictionary
    (as produced for streaming) or a ``GenerateContentResponse``-like object.
    The ``usage_metadata`` payload may itself be a dictionary or an object
    exposing the counters as attributes.

    Args:
        response: The GenerateContentResponse object or dictionary containing
            usage metadata.

    Returns:
        UsageData: Dictionary with input_tokens, input_tokens_cached,
            output_tokens, output_tokens_reasoning and total_tokens fields.
            Any counter that is missing or ``None`` is reported as 0.
    """
    if isinstance(response, dict):
        usage = response.get("usage_metadata")
    else:
        usage = getattr(response, "usage_metadata", None)

    def _count(field):
        # type: (str) -> int
        """Read one token counter from ``usage``; missing/None counts as 0."""
        if usage is None:
            return 0
        if isinstance(usage, dict):
            value = usage.get(field)
        else:
            value = getattr(usage, field, None)
        return value or 0

    reasoning_tokens = _count("thoughts_token_count")

    usage_data = {
        # Input tokens include both prompt and tool use prompt tokens
        "input_tokens": (
            _count("prompt_token_count") + _count("tool_use_prompt_token_count")
        ),
        "input_tokens_cached": _count("cached_content_token_count"),
        # google-genai reports output and reasoning tokens separately, but
        # reasoning should be a sub-category of output tokens.
        "output_tokens": _count("candidates_token_count") + reasoning_tokens,
        "output_tokens_reasoning": reasoning_tokens,
        "total_tokens": _count("total_token_count"),
    }  # type: UsageData
    return usage_data


def _capture_exception(exc):
    # type: (Any) -> None
    """Report ``exc`` to Sentry, tagged with an unhandled ``google_genai`` mechanism."""
    client_options = sentry_sdk.get_client().options
    mechanism = {"type": "google_genai", "handled": False}

    event, hint = event_from_exception(
        exc, client_options=client_options, mechanism=mechanism
    )
    sentry_sdk.capture_event(event, hint=hint)


def get_model_name(model):
    # type: (Union[str, Model]) -> str
    """Return a printable name for ``model``.

    Strings pass through unchanged. Objects exposing a ``name`` attribute are
    reduced to ``str(model.name)``. Anything else falls back to ``str(model)``.
    """
    if isinstance(model, str):
        return model
    source = model.name if hasattr(model, "name") else model
    return str(source)


def extract_contents_text(contents):
    # type: (ContentListUnion) -> Optional[str]
    """Flatten ``contents`` into one string, or return None if it holds no text.

    Shapes are checked in this order:

    * ``None`` -> None
    * ``str`` -> returned unchanged
    * ``list`` -> each item is flattened recursively; truthy results are
      joined with single spaces (None if there are none)
    * ``dict`` -> its ``"text"`` value if present, else the flattened
      ``"parts"`` value if present
    * object with a truthy ``parts`` attribute -> flattened ``parts``
    * otherwise -> the object's ``text`` attribute, or None if it has none
    """
    if contents is None:
        return None
    if isinstance(contents, str):
        return contents

    if isinstance(contents, list):
        pieces = [
            piece
            for piece in (extract_contents_text(element) for element in contents)
            if piece
        ]
        if not pieces:
            return None
        return " ".join(pieces)

    if isinstance(contents, dict):
        if "text" in contents:
            return contents["text"]
        if "parts" in contents:
            return extract_contents_text(contents["parts"])

    # Content-like objects: descend into their parts when there are any.
    nested_parts = getattr(contents, "parts", None)
    if nested_parts:
        return extract_contents_text(nested_parts)

    # Part-like objects expose the text directly.
    return getattr(contents, "text", None)


def _format_tools_for_span(tools):
    # type: (Iterable[Tool | Callable[..., Any]]) -> Optional[List[dict[str, Any]]]
    """Describe each tool as a ``{"name", "description"}`` dict for span data.

    * Plain callables contribute their ``__name__`` (or ``"unknown"``) and
      ``__doc__``.
    * Tool objects whose ``function_declarations`` is not None contribute one
      entry per declaration.
    * Any other tool contributes at most one entry: the first attribute listed
      in ``TOOL_ATTRIBUTES_MAP`` that is truthy on it, paired with that map's
      description.

    Returns None when no entries were produced.
    """
    described = []  # type: List[dict[str, Any]]
    for tool in tools:
        if callable(tool):
            described.append(
                {
                    "name": getattr(tool, "__name__", "unknown"),
                    "description": getattr(tool, "__doc__", None),
                }
            )
            continue

        declarations = getattr(tool, "function_declarations", None)
        if declarations is not None:
            described.extend(
                {
                    "name": getattr(decl, "name", None),
                    "description": getattr(decl, "description", None),
                }
                for decl in declarations
            )
            continue

        # Predefined tools are attributes of the tool object that stay None
        # unless enabled; only the first enabled one is reported.
        first_enabled = next(
            (
                (attr_name, text)
                for attr_name, text in TOOL_ATTRIBUTES_MAP.items()
                if getattr(tool, attr_name, None)
            ),
            None,
        )
        if first_enabled is not None:
            attr_name, text = first_enabled
            described.append({"name": attr_name, "description": text})

    return described or None


def _tool_calls_from_parts(parts):
    # type: (Iterable[Any]) -> List[dict[str, Any]]
    """Build one tool-call record for each part that has a truthy ``function_call``."""
    calls = []  # type: List[dict[str, Any]]
    for part in parts:
        function_call = getattr(part, "function_call", None)
        if not function_call:
            continue

        tool_call = {
            "name": getattr(function_call, "name", None),
            "type": "function_call",
        }  # type: dict[str, Any]

        # Missing or empty ``args`` are skipped, so no serialized None ends up
        # in the record.
        args = getattr(function_call, "args", None)
        if args:
            tool_call["arguments"] = safe_serialize(args)

        calls.append(tool_call)
    return calls


def extract_tool_calls(response):
    # type: (GenerateContentResponse) -> Optional[List[dict[str, Any]]]
    """Extract tool/function calls from response candidates and automatic function calling history.

    Calls found in ``candidates[*].content.parts`` come first. They are
    followed by calls found in ``automatic_function_calling_history[*].parts``.
    Both sources are processed identically, and ``"arguments"`` is included
    only when the call has non-empty ``args``.

    Returns:
        A list of ``{"name", "type"[, "arguments"]}`` dicts, or None if no
        tool calls were found.
    """
    tool_calls = []  # type: List[dict[str, Any]]

    # Tool calls in this response are nested under candidate.content.parts;
    # content or parts may be missing or None.
    for candidate in getattr(response, "candidates", None) or []:
        content = getattr(candidate, "content", None)
        parts = getattr(content, "parts", None) or []
        tool_calls.extend(_tool_calls_from_parts(parts))

    # History of tool calls made by the model via automatic function calling.
    for content in getattr(response, "automatic_function_calling_history", None) or []:
        parts = getattr(content, "parts", None) or []
        tool_calls.extend(_tool_calls_from_parts(parts))

    return tool_calls or None


def _capture_tool_input(args, kwargs, tool):
    # type: (tuple[Any, ...], dict[str, Any], Tool) -> dict[str, Any]
    """Capture tool input from args and kwargs.

    Keyword arguments are copied as-is. Positional arguments are bound to
    parameter names with ``inspect.Signature.bind_partial``. As a result, a
    ``*args``-style parameter receives the full tuple of remaining values, and
    positional values are never attributed to keyword-only or ``**kwargs``
    parameters.

    If the signature cannot be inspected, or the positional arguments do not
    fit it, the raw tuple is recorded under ``"args"`` instead. In
    ``wrapped_tool`` this function runs outside ``capture_internal_exceptions``
    and before the tool itself, so it must not raise.
    """
    tool_input = kwargs.copy() if kwargs else {}

    if args:
        try:
            bound = inspect.signature(tool).bind_partial(*args)
        except Exception:
            # Best effort: uninspectable callables, or positional arguments
            # that do not match the signature.
            tool_input["args"] = args
        else:
            # Positional values take precedence over same-named kwargs,
            # matching the previous mapping order.
            tool_input.update(bound.arguments)

    return tool_input


def _create_tool_span(tool_name, tool_doc):
    # type: (str, Optional[str]) -> Span
    """Start an ``execute_tool`` span for ``tool_name`` and tag it with tool metadata.

    The description is only attached when ``tool_doc`` is truthy.
    """
    span = sentry_sdk.start_span(
        op=OP.GEN_AI_EXECUTE_TOOL,
        name=f"execute_tool {tool_name}",
        origin=ORIGIN,
    )

    attributes = [
        (SPANDATA.GEN_AI_TOOL_NAME, tool_name),
        (SPANDATA.GEN_AI_TOOL_TYPE, "function"),
    ]
    if tool_doc:
        attributes.append((SPANDATA.GEN_AI_TOOL_DESCRIPTION, tool_doc))

    for key, value in attributes:
        span.set_data(key, value)
    return span


def wrapped_tool(tool):
    # type: (Tool | Callable[..., Any]) -> Tool | Callable[..., Any]
    """Return ``tool`` wrapped so that each call runs inside an execute_tool span.

    Non-callable tools (predefined tool objects) are returned unchanged.
    Callables are wrapped with ``functools.wraps``. Coroutine functions get an
    ``async`` wrapper, so the span covers the awaited call.

    The wrapper records the tool input before the call and the serialized
    result after it. Any ``Exception`` raised by the tool is reported through
    ``_capture_exception`` and then re-raised.
    """
    if not callable(tool):
        return tool

    tool_name = getattr(tool, "__name__", "unknown")
    tool_doc = tool.__doc__

    def _record_input(span, call_args, call_kwargs):
        # type: (Span, tuple[Any, ...], dict[str, Any]) -> None
        """Attach the name-mapped call arguments to ``span``."""
        tool_input = _capture_tool_input(call_args, call_kwargs, tool)
        with capture_internal_exceptions():
            span.set_data(SPANDATA.GEN_AI_TOOL_INPUT, safe_serialize(tool_input))

    def _record_output(span, result):
        # type: (Span, Any) -> None
        """Attach the serialized tool result to ``span``."""
        with capture_internal_exceptions():
            span.set_data(SPANDATA.GEN_AI_TOOL_OUTPUT, safe_serialize(result))

    if not inspect.iscoroutinefunction(tool):

        @wraps(tool)
        def sync_wrapped(*args, **kwargs):
            # type: (Any, Any) -> Any
            with _create_tool_span(tool_name, tool_doc) as span:
                _record_input(span, args, kwargs)
                try:
                    result = tool(*args, **kwargs)
                except Exception as exc:
                    _capture_exception(exc)
                    raise
                _record_output(span, result)
                return result

        return sync_wrapped

    @wraps(tool)
    async def async_wrapped(*args, **kwargs):
        # type: (Any, Any) -> Any
        with _create_tool_span(tool_name, tool_doc) as span:
            _record_input(span, args, kwargs)
            try:
                result = await tool(*args, **kwargs)
            except Exception as exc:
                _capture_exception(exc)
                raise
            _record_output(span, result)
            return result

    return async_wrapped


def wrapped_config_with_tools(config):
    # type: (GenerateContentConfig) -> GenerateContentConfig
    """Return a config whose tools are wrapped to emit execute_tool spans.

    Tools are sometimes passed directly as callable functions on the config
    object. When the config is falsy or has no truthy ``tools``, it is
    returned unchanged. Otherwise a shallow copy is returned with each tool
    passed through ``wrapped_tool``, so the caller's config is left
    untouched.
    """
    original_tools = getattr(config, "tools", None) if config else None
    if not original_tools:
        return config

    patched = copy.copy(config)
    patched.tools = list(map(wrapped_tool, original_tools))
    return patched


def _extract_response_text(response):
    # type: (GenerateContentResponse) -> Optional[List[str]]
    """Extract text from response candidates.

    Returns the truthy ``text`` of every part of every candidate, in order,
    or None when there is none. Candidates whose ``content`` or
    ``content.parts`` is missing or None are skipped.
    """
    if not response or not getattr(response, "candidates", []):
        return None

    texts = []  # type: List[str]
    for candidate in response.candidates:
        content = getattr(candidate, "content", None)
        # ``parts`` can exist but be None, so guard on its value rather than
        # on the attribute's existence (iterating None raises TypeError).
        parts = getattr(content, "parts", None) or []
        for part in parts:
            text = getattr(part, "text", None)
            if text:
                texts.append(text)

    return texts or None


def extract_finish_reasons(response):
    # type: (GenerateContentResponse) -> Optional[List[str]]
    """Collect each candidate's finish reason as a bare string.

    The reason is stringified and only the text after its last ``.`` is kept,
    so an enum rendering such as ``"FinishReason.STOP"`` becomes ``"STOP"``.
    Candidates without a truthy ``finish_reason`` are skipped. Returns None
    when nothing was collected.
    """
    if not response:
        return None
    candidates = getattr(response, "candidates", [])
    if not candidates:
        return None

    reasons = []
    for candidate in candidates:
        raw_reason = getattr(candidate, "finish_reason", None)
        if not raw_reason:
            continue
        reasons.append(str(raw_reason).rpartition(".")[2])

    return reasons if reasons else None


def set_span_data_for_request(span, integration, model, contents, kwargs):
    # type: (Span, Any, str, ContentListUnion, dict[str, Any]) -> None
    """Set span data for the request.

    Always records the GenAI system and the requested model. The streaming
    flag is also recorded when ``kwargs["stream"]`` is truthy.

    When PII may be sent and ``integration.include_prompts`` is set, request
    messages are recorded: the system instruction (if any) and the user
    ``contents``. This happens whether or not a config was passed.

    Sampling parameters and available tools are read from ``kwargs["config"]``.
    That config may be a ``GenerateContentConfig``-like object or a plain dict.
    """
    span.set_data(SPANDATA.GEN_AI_SYSTEM, GEN_AI_SYSTEM)
    span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)

    if kwargs.get("stream", False):
        span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, True)

    config = kwargs.get("config")  # type: Optional[Union[GenerateContentConfig, dict[str, Any]]]

    def _config_value(name):
        # type: (str) -> Any
        """Read ``name`` from the config object or dict; None when absent."""
        if config is None:
            return None
        if isinstance(config, dict):
            return config.get(name)
        return getattr(config, name, None)

    # Set input messages/prompts if PII is allowed. The user contents do not
    # depend on the config, so they are captured even without one.
    if should_send_default_pii() and integration.include_prompts:
        messages = []

        # Add system instruction if present
        system_instruction = _config_value("system_instruction")
        if system_instruction:
            system_text = extract_contents_text(system_instruction)
            if system_text:
                messages.append({"role": "system", "content": system_text})

        # Add user message
        contents_text = extract_contents_text(contents)
        if contents_text:
            messages.append({"role": "user", "content": contents_text})

        if messages:
            normalized_messages = normalize_message_roles(messages)
            scope = sentry_sdk.get_current_scope()
            messages_data = truncate_and_annotate_messages(
                normalized_messages, span, scope
            )
            if messages_data is not None:
                set_data_normalized(
                    span,
                    SPANDATA.GEN_AI_REQUEST_MESSAGES,
                    messages_data,
                    unpack=False,
                )

    # Parameters live directly on the config (not nested under generation_config)
    for param, span_key in [
        ("temperature", SPANDATA.GEN_AI_REQUEST_TEMPERATURE),
        ("top_p", SPANDATA.GEN_AI_REQUEST_TOP_P),
        ("top_k", SPANDATA.GEN_AI_REQUEST_TOP_K),
        ("max_output_tokens", SPANDATA.GEN_AI_REQUEST_MAX_TOKENS),
        ("presence_penalty", SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY),
        ("frequency_penalty", SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY),
        ("seed", SPANDATA.GEN_AI_REQUEST_SEED),
    ]:
        value = _config_value(param)
        if value is not None:
            span.set_data(span_key, value)

    # Set tools if available
    tools = _config_value("tools")
    if tools:
        formatted_tools = _format_tools_for_span(tools)
        if formatted_tools:
            set_data_normalized(
                span,
                SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
                formatted_tools,
                unpack=False,
            )


def set_span_data_for_response(span, integration, response):
    # type: (Span, Any, GenerateContentResponse) -> None
    """Record response attributes on ``span``.

    The recorded attributes are text, tool calls, finish reasons, ids, model
    and token usage. A falsy response is ignored entirely. Response text is
    only attached when PII may be sent and ``integration.include_prompts`` is
    set. Every other attribute is attached only when its value is truthy.
    """
    if not response:
        return

    if should_send_default_pii() and integration.include_prompts:
        texts = _extract_response_text(response)
        if texts:
            # Stored as a JSON-serialized array of strings.
            span.set_data(SPANDATA.GEN_AI_RESPONSE_TEXT, safe_serialize(texts))

    calls = extract_tool_calls(response)
    if calls:
        span.set_data(SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS, safe_serialize(calls))

    reasons = extract_finish_reasons(response)
    if reasons:
        set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS, reasons)

    for attr_name, span_key in (
        ("response_id", SPANDATA.GEN_AI_RESPONSE_ID),
        ("model_version", SPANDATA.GEN_AI_RESPONSE_MODEL),
    ):
        attr_value = getattr(response, attr_name, None)
        if attr_value:
            span.set_data(span_key, attr_value)

    usage = extract_usage_data(response)
    for usage_key, span_key in (
        ("input_tokens", SPANDATA.GEN_AI_USAGE_INPUT_TOKENS),
        ("input_tokens_cached", SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED),
        ("output_tokens", SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS),
        ("output_tokens_reasoning", SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING),
        ("total_tokens", SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS),
    ):
        count = usage[usage_key]
        if count:
            span.set_data(span_key, count)


def prepare_generate_content_args(args, kwargs):
    # type: (tuple[Any, ...], dict[str, Any]) -> tuple[Any, Any, str]
    """Resolve model/contents for a generate_content call and wrap config tools.

    ``model`` and ``contents`` are taken positionally when present, otherwise
    from ``kwargs``; ``model`` defaults to ``"unknown"``. If wrapping the
    config's tools produced a new config object, it is written back into
    ``kwargs["config"]``. This mutates the caller's dict, so a caller that
    forwards ``kwargs`` passes the wrapped tools on.

    Returns:
        ``(model, contents, model_name)``.
    """
    if args:
        model = args[0]
    else:
        model = kwargs.get("model", "unknown")

    if len(args) > 1:
        contents = args[1]
    else:
        contents = kwargs.get("contents")

    model_name = get_model_name(model)

    original_config = kwargs.get("config")
    patched_config = wrapped_config_with_tools(original_config)
    if patched_config is not original_config:
        kwargs["config"] = patched_config

    return model, contents, model_name