Log LLM tool call for streamed response (#545)
Co-authored-by: Alex Hall <[email protected]>
jackmpcollins and alexmojaki authored Nov 13, 2024
1 parent ff7211b commit 68fcf5a
Showing 6 changed files with 403 additions and 38 deletions.
18 changes: 15 additions & 3 deletions logfire/_internal/integrations/llm_providers/anthropic.py
@@ -5,7 +5,7 @@
 import anthropic
 from anthropic.types import Message, TextBlock, TextDelta
 
-from .types import EndpointConfig
+from .types import EndpointConfig, StreamState
 
 if TYPE_CHECKING:
     from anthropic._models import FinalRequestOptions
@@ -32,13 +32,12 @@ def get_endpoint_config(options: FinalRequestOptions) -> EndpointConfig:
         return EndpointConfig(
             message_template='Message with {request_data[model]!r}',
             span_data={'request_data': json_data},
-            content_from_stream=content_from_messages,
+            stream_state_cls=AnthropicMessageStreamState,
         )
     else:
         return EndpointConfig(
             message_template='Anthropic API call to {url!r}',
             span_data={'request_data': json_data, 'url': url},
-            content_from_stream=None,
         )
 
 
@@ -50,6 +49,19 @@ def content_from_messages(chunk: anthropic.types.MessageStreamEvent) -> str | None:
     return None
 
 
+class AnthropicMessageStreamState(StreamState):
+    def __init__(self):
+        self._content: list[str] = []
+
+    def record_chunk(self, chunk: anthropic.types.MessageStreamEvent) -> None:
+        content = content_from_messages(chunk)
+        if content:
+            self._content.append(content)
+
+    def get_response_data(self) -> Any:
+        return {'combined_chunk_content': ''.join(self._content), 'chunk_count': len(self._content)}
+
+
 def on_response(response: ResponseT, span: LogfireSpan) -> ResponseT:
     """Updates the span based on the type of response."""
     if isinstance(response, Message):  # pragma: no branch
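As a quick illustration of what the new Anthropic stream state aggregates, the following sketch (not part of the commit) feeds a few text-delta events into AnthropicMessageStreamState and reads back the response data. It assumes the RawContentBlockDeltaEvent and TextDelta models from recent anthropic SDK releases and that content_from_messages extracts text from content-block deltas; field names may vary between SDK versions.

# Illustrative sketch only -- event model names/fields assume a recent anthropic SDK.
from anthropic.types import RawContentBlockDeltaEvent, TextDelta

from logfire._internal.integrations.llm_providers.anthropic import AnthropicMessageStreamState

state = AnthropicMessageStreamState()
for text in ('Hello', ', ', 'world'):
    # Each streamed text fragment arrives as a content_block_delta event.
    event = RawContentBlockDeltaEvent(
        type='content_block_delta',
        index=0,
        delta=TextDelta(type='text_delta', text=text),
    )
    state.record_chunk(event)

# This is the payload attached to the log as response_data.
print(state.get_response_data())
# expected: {'combined_chunk_content': 'Hello, world', 'chunk_count': 3}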
23 changes: 11 additions & 12 deletions logfire/_internal/integrations/llm_providers/llm_provider.py
@@ -9,7 +9,7 @@
 
 if TYPE_CHECKING:
     from ...main import Logfire, LogfireSpan
-    from .types import EndpointConfig
+    from .types import EndpointConfig, StreamState
 
 
 __all__ = ('instrument_llm_provider',)
@@ -76,21 +76,21 @@ def _instrumentation_setup(**kwargs: Any) -> Any:
     if is_instrumentation_suppressed():
         return None, None, kwargs
 
-    message_template, span_data, content_from_stream = get_endpoint_config_fn(kwargs['options'])
+    message_template, span_data, stream_state_cls = get_endpoint_config_fn(kwargs['options'])
 
     span_data['async'] = is_async
 
     stream = kwargs['stream']
 
-    if stream and content_from_stream:
+    if stream and stream_state_cls:
         stream_cls = kwargs['stream_cls']
         assert stream_cls is not None, 'Expected `stream_cls` when streaming'
 
         if is_async:
 
             class LogfireInstrumentedAsyncStream(stream_cls):
                 async def __stream__(self) -> AsyncIterator[Any]:
-                    with record_streaming(logfire_llm, span_data, content_from_stream) as record_chunk:
+                    with record_streaming(logfire_llm, span_data, stream_state_cls) as record_chunk:
                         async for chunk in super().__stream__():  # type: ignore
                             record_chunk(chunk)
                             yield chunk
@@ -100,7 +100,7 @@ async def __stream__(self) -> AsyncIterator[Any]:
 
             class LogfireInstrumentedStream(stream_cls):
                 def __stream__(self) -> Iterator[Any]:
-                    with record_streaming(logfire_llm, span_data, content_from_stream) as record_chunk:
+                    with record_streaming(logfire_llm, span_data, stream_state_cls) as record_chunk:
                         for chunk in super().__stream__():  # type: ignore
                             record_chunk(chunk)
                             yield chunk
@@ -174,14 +174,13 @@ def maybe_suppress_instrumentation(suppress: bool) -> Iterator[None]:
 def record_streaming(
     logire_llm: Logfire,
     span_data: dict[str, Any],
-    content_from_stream: Callable[[Any], str | None],
+    stream_state_cls: type[StreamState],
 ):
-    content: list[str] = []
+    stream_state = stream_state_cls()
 
-    def record_chunk(chunk: Any) -> Any:
-        chunk_content = content_from_stream(chunk)
-        if chunk_content:
-            content.append(chunk_content)
+    def record_chunk(chunk: Any) -> None:
+        if chunk:
+            stream_state.record_chunk(chunk)
 
     timer = logire_llm._config.advanced.ns_timestamp_generator  # type: ignore
     start = timer()
@@ -193,5 +192,5 @@ def record_chunk(chunk: Any) -> Any:
             'streaming response from {request_data[model]!r} took {duration:.2f}s',
             **span_data,
             duration=duration,
-            response_data={'combined_chunk_content': ''.join(content), 'chunk_count': len(content)},
+            response_data=stream_state.get_response_data(),
         )
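The reworked record_streaming above keeps the same shape as before: build some per-stream state, hand the caller a record_chunk callback, and emit one log line with the aggregated data and the elapsed time when the stream finishes. A simplified standalone sketch of that pattern (using time.monotonic and print as stand-ins for Logfire's timestamp generator and log call):

# Simplified sketch of the record_streaming pattern; timing and logging are stand-ins.
import time
from contextlib import contextmanager
from typing import Any, Callable, Iterator


@contextmanager
def record_streaming_sketch(stream_state_cls: type) -> Iterator[Callable[[Any], None]]:
    stream_state = stream_state_cls()  # fresh state object per streamed response

    def record_chunk(chunk: Any) -> None:
        if chunk:
            stream_state.record_chunk(chunk)  # aggregation is delegated to the state class

    start = time.monotonic()
    try:
        yield record_chunk  # the instrumented __stream__ calls this once per chunk
    finally:
        duration = time.monotonic() - start
        # In the real code this is a Logfire log line carrying span_data as well.
        print(f'streaming response took {duration:.2f}s', stream_state.get_response_data())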
57 changes: 46 additions & 11 deletions logfire/_internal/integrations/llm_providers/openai.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Any, cast
 
 import openai
 from openai._legacy_response import LegacyAPIResponse
@@ -10,7 +10,7 @@
 from openai.types.create_embedding_response import CreateEmbeddingResponse
 from openai.types.images_response import ImagesResponse
 
-from .types import EndpointConfig
+from .types import EndpointConfig, StreamState
 
 if TYPE_CHECKING:
     from openai._models import FinalRequestOptions
@@ -38,31 +38,28 @@ def get_endpoint_config(options: FinalRequestOptions) -> EndpointConfig:
         return EndpointConfig(
             message_template='Chat Completion with {request_data[model]!r}',
             span_data={'request_data': json_data},
-            content_from_stream=content_from_chat_completions,
+            stream_state_cls=OpenaiChatCompletionStreamState,
         )
     elif url == '/completions':
         return EndpointConfig(
             message_template='Completion with {request_data[model]!r}',
             span_data={'request_data': json_data},
-            content_from_stream=content_from_completions,
+            stream_state_cls=OpenaiCompletionStreamState,
         )
     elif url == '/embeddings':
         return EndpointConfig(
             message_template='Embedding Creation with {request_data[model]!r}',
             span_data={'request_data': json_data},
-            content_from_stream=None,
         )
     elif url == '/images/generations':
         return EndpointConfig(
             message_template='Image Generation with {request_data[model]!r}',
             span_data={'request_data': json_data},
-            content_from_stream=None,
         )
     else:
         return EndpointConfig(
             message_template='OpenAI API call to {url!r}',
             span_data={'request_data': json_data, 'url': url},
-            content_from_stream=None,
         )
 
 
@@ -72,10 +69,48 @@ def content_from_completions(chunk: Completion | None) -> str | None:
     return None  # pragma: no cover
 
 
-def content_from_chat_completions(chunk: ChatCompletionChunk | None) -> str | None:
-    if chunk and chunk.choices:
-        return chunk.choices[0].delta.content
-    return None
+class OpenaiCompletionStreamState(StreamState):
+    def __init__(self):
+        self._content: list[str] = []
+
+    def record_chunk(self, chunk: Completion) -> None:
+        content = content_from_completions(chunk)
+        if content:
+            self._content.append(content)
+
+    def get_response_data(self) -> Any:
+        return {'combined_chunk_content': ''.join(self._content), 'chunk_count': len(self._content)}
+
+
+try:
+    # ChatCompletionStreamState only exists in openai>=1.40.0
+    from openai.lib.streaming.chat._completions import ChatCompletionStreamState
+
+    class OpenaiChatCompletionStreamState(StreamState):
+        def __init__(self):
+            self._stream_state = ChatCompletionStreamState(
+                # We do not need the response to be parsed into Python objects so can skip
+                # providing the `response_format` and `input_tools` arguments.
+                input_tools=openai.NOT_GIVEN,
+                response_format=openai.NOT_GIVEN,
+            )
+
+        def record_chunk(self, chunk: ChatCompletionChunk) -> None:
+            self._stream_state.handle_chunk(chunk)
+
+        def get_response_data(self) -> Any:
+            try:
+                final_completion = self._stream_state.current_completion_snapshot
+            except AssertionError:
+                # AssertionError is raised when there is no completion snapshot
+                # Return empty content to show an empty Assistant response in the UI
+                return {'combined_chunk_content': '', 'chunk_count': 0}
+            return {
+                'message': final_completion.choices[0].message if final_completion.choices else None,
+                'usage': final_completion.usage,
+            }
+except ImportError:  # pragma: no cover
+    OpenaiChatCompletionStreamState = OpenaiCompletionStreamState  # type: ignore
 
 
 def on_response(response: ResponseT, span: LogfireSpan) -> ResponseT:
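From the caller's perspective, this is the change that makes the commit title true: a streamed chat completion whose chunks contain only tool-call deltas no longer ends up logged as an empty combined_chunk_content; the assembled assistant message (including tool_calls) and usage are recorded instead. A rough usage sketch, assuming a configured OpenAI API key and openai>=1.40.0 (the model name and tool schema are illustrative):

# Rough usage sketch -- assumes OPENAI_API_KEY is set and openai>=1.40.0 is installed.
import logfire
import openai

logfire.configure()
client = openai.OpenAI()
logfire.instrument_openai(client)  # instruments this client, including streamed responses

stream = client.chat.completions.create(
    model='gpt-4o',  # illustrative model name
    messages=[{'role': 'user', 'content': "What's the weather in Paris?"}],
    tools=[{
        'type': 'function',
        'function': {'name': 'get_weather', 'parameters': {'type': 'object', 'properties': {}}},
    }],
    stream=True,
)
for _chunk in stream:
    pass  # consuming the stream drives record_chunk() for every chunk

# When the stream is exhausted, the 'streaming response ...' log line's response_data
# contains the accumulated message (with any tool_calls) and usage, not just joined text.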
17 changes: 15 additions & 2 deletions logfire/_internal/integrations/llm_providers/types.py
@@ -1,13 +1,26 @@
 from __future__ import annotations
 
-from typing import Any, Callable, NamedTuple
+from abc import ABC, abstractmethod
+from typing import Any, NamedTuple
 
 from typing_extensions import LiteralString
 
 
+class StreamState(ABC):
+    """Keeps track of the state of a streamed response."""
+
+    @abstractmethod
+    def record_chunk(self, chunk: Any) -> None:
+        """Update the state based on a chunk from the streamed response."""
+
+    @abstractmethod
+    def get_response_data(self) -> Any:
+        """Returns the response data for including in the log."""
+
+
 class EndpointConfig(NamedTuple):
     """The configuration for the endpoint of a provider based on request url."""
 
     message_template: LiteralString
     span_data: dict[str, Any]
-    content_from_stream: Callable[[Any], str | None] | None
+    stream_state_cls: type[StreamState] | None = None
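The StreamState ABC is now the extension point a provider integration implements, replacing the content_from_stream callable; EndpointConfig.stream_state_cls stays None for endpoints that never stream. A minimal sketch of a subclass for a hypothetical provider (the class and attribute names here are illustrative, not part of the commit):

# Hypothetical example of implementing StreamState for a new provider integration.
from typing import Any

from logfire._internal.integrations.llm_providers.types import StreamState


class MyProviderStreamState(StreamState):  # hypothetical provider
    def __init__(self) -> None:
        self._texts: list[str] = []
        self._chunk_count = 0

    def record_chunk(self, chunk: Any) -> None:
        # Called once per streamed chunk; keep whatever you need for the final log.
        self._chunk_count += 1
        text = getattr(chunk, 'text', None)
        if text:
            self._texts.append(text)

    def get_response_data(self) -> Any:
        # Returned once, when the stream finishes, and attached to the log as response_data.
        return {'combined_chunk_content': ''.join(self._texts), 'chunk_count': self._chunk_count}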
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -231,7 +231,8 @@ exclude = [
     "out",
     "logfire-api",
 ]
-venvPath = ".venv"
+venvPath = "."
+venv = ".venv"
 
 [tool.pytest.ini_options]
 xfail_strict = true