Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Log LLM tool call for streamed response #545

Merged
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
a55495d
Add tests for tool call stream w/o snapshot
jackmpcollins Oct 26, 2024
8207557
Fix pyright venv config
jackmpcollins Oct 27, 2024
7e7ca8a
Replace content_from_stream with stream_state_cls
jackmpcollins Oct 27, 2024
6b306bc
Handle null chunk
jackmpcollins Oct 27, 2024
e603489
Add null response for empty choices
jackmpcollins Oct 27, 2024
bf70fe0
exclude_unset in httpx.Response to fix state parsing
jackmpcollins Oct 27, 2024
d27fb6f
Uncomment stream options. Add snapshots
jackmpcollins Oct 27, 2024
a959b31
Improve stream state params comment
jackmpcollins Oct 27, 2024
c8a940c
Use current snapshot to display partial responses
jackmpcollins Oct 27, 2024
3137921
Fix index in streamed text test response
jackmpcollins Oct 27, 2024
57963b6
fix snapshots
jackmpcollins Oct 27, 2024
5f4c326
Update snapshots for anthropic tests
jackmpcollins Oct 27, 2024
c93c5a5
Fix typo: AnthropicMessageStreamState
jackmpcollins Oct 27, 2024
5286f87
Make stream_state_cls required in record_streaming
jackmpcollins Oct 27, 2024
e88f71f
Merge branch 'main' into log-streamed-tool-call-response
jackmpcollins Oct 30, 2024
6119fef
Remove unneeded ...
jackmpcollins Oct 30, 2024
895460f
Merge branch 'main' into log-streamed-tool-call-response
jackmpcollins Nov 10, 2024
7045b77
Fix: assistantassistantassistant role
jackmpcollins Nov 10, 2024
b3b31c6
Add chunk_count to empty response response_data
jackmpcollins Nov 11, 2024
b5b26e8
Fall back to OpenaiCompletionStreamState if import unavailable
jackmpcollins Nov 12, 2024
c3d813e
Merge branch 'main' into log-streamed-tool-call-response
alexmojaki Nov 12, 2024
6f512b8
Merge branch 'main' into log-streamed-tool-call-response
alexmojaki Nov 13, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions logfire/_internal/integrations/llm_providers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import anthropic
from anthropic.types import Message, TextBlock, TextDelta

from .types import EndpointConfig
from .types import EndpointConfig, StreamState

if TYPE_CHECKING:
from anthropic._models import FinalRequestOptions
Expand All @@ -32,13 +32,12 @@ def get_endpoint_config(options: FinalRequestOptions) -> EndpointConfig:
return EndpointConfig(
message_template='Message with {request_data[model]!r}',
span_data={'request_data': json_data},
content_from_stream=content_from_messages,
stream_state_cls=AnthropicMessageStreamState,
)
else:
return EndpointConfig(
message_template='Anthropic API call to {url!r}',
span_data={'request_data': json_data, 'url': url},
content_from_stream=None,
)


Expand All @@ -50,6 +49,19 @@ def content_from_messages(chunk: anthropic.types.MessageStreamEvent) -> str | No
return None


class AnthropicMessageStreamState(StreamState):
    """Accumulates the text content of a streamed Anthropic message."""

    def __init__(self) -> None:
        # Non-empty text fragments, kept in the order they were recorded.
        self._content: list[str] = []

    def record_chunk(self, chunk: anthropic.types.MessageStreamEvent) -> None:
        """Store the text of `chunk` if `content_from_messages` extracts any."""
        content = content_from_messages(chunk)
        if content:
            self._content.append(content)

    def get_response_data(self) -> Any:
        """Return the joined text plus the number of chunks that contributed text.

        `chunk_count` only counts chunks that yielded non-empty content, not
        every chunk passed to `record_chunk`.
        """
        return {'combined_chunk_content': ''.join(self._content), 'chunk_count': len(self._content)}


def on_response(response: ResponseT, span: LogfireSpan) -> ResponseT:
"""Updates the span based on the type of response."""
if isinstance(response, Message): # pragma: no branch
Expand Down
23 changes: 11 additions & 12 deletions logfire/_internal/integrations/llm_providers/llm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

if TYPE_CHECKING:
from ...main import Logfire, LogfireSpan
from .types import EndpointConfig
from .types import EndpointConfig, StreamState


__all__ = ('instrument_llm_provider',)
Expand Down Expand Up @@ -76,21 +76,21 @@ def _instrumentation_setup(**kwargs: Any) -> Any:
if is_instrumentation_suppressed():
return None, None, kwargs

message_template, span_data, content_from_stream = get_endpoint_config_fn(kwargs['options'])
message_template, span_data, stream_state_cls = get_endpoint_config_fn(kwargs['options'])

span_data['async'] = is_async

stream = kwargs['stream']

if stream and content_from_stream:
if stream and stream_state_cls:
stream_cls = kwargs['stream_cls']
assert stream_cls is not None, 'Expected `stream_cls` when streaming'

if is_async:

class LogfireInstrumentedAsyncStream(stream_cls):
async def __stream__(self) -> AsyncIterator[Any]:
with record_streaming(logfire_llm, span_data, content_from_stream) as record_chunk:
with record_streaming(logfire_llm, span_data, stream_state_cls) as record_chunk:
async for chunk in super().__stream__(): # type: ignore
record_chunk(chunk)
yield chunk
Expand All @@ -100,7 +100,7 @@ async def __stream__(self) -> AsyncIterator[Any]:

class LogfireInstrumentedStream(stream_cls):
def __stream__(self) -> Iterator[Any]:
with record_streaming(logfire_llm, span_data, content_from_stream) as record_chunk:
with record_streaming(logfire_llm, span_data, stream_state_cls) as record_chunk:
for chunk in super().__stream__(): # type: ignore
record_chunk(chunk)
yield chunk
Expand Down Expand Up @@ -174,14 +174,13 @@ def maybe_suppress_instrumentation(suppress: bool) -> Iterator[None]:
def record_streaming(
logire_llm: Logfire,
span_data: dict[str, Any],
content_from_stream: Callable[[Any], str | None],
stream_state_cls: type[StreamState],
):
content: list[str] = []
stream_state = stream_state_cls()

def record_chunk(chunk: Any) -> Any:
chunk_content = content_from_stream(chunk)
if chunk_content:
content.append(chunk_content)
def record_chunk(chunk: Any) -> None:
if chunk:
stream_state.record_chunk(chunk)

timer = logire_llm._config.advanced.ns_timestamp_generator # type: ignore
start = timer()
Expand All @@ -193,5 +192,5 @@ def record_chunk(chunk: Any) -> Any:
'streaming response from {request_data[model]!r} took {duration:.2f}s',
**span_data,
duration=duration,
response_data={'combined_chunk_content': ''.join(content), 'chunk_count': len(content)},
response_data=stream_state.get_response_data(),
)
52 changes: 41 additions & 11 deletions logfire/_internal/integrations/llm_providers/openai.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from __future__ import annotations

from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Any, cast

import openai
from openai._legacy_response import LegacyAPIResponse
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
alexmojaki marked this conversation as resolved.
Show resolved Hide resolved
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.completion import Completion
from openai.types.create_embedding_response import CreateEmbeddingResponse
from openai.types.images_response import ImagesResponse

from .types import EndpointConfig
from .types import EndpointConfig, StreamState

if TYPE_CHECKING:
from openai._models import FinalRequestOptions
Expand Down Expand Up @@ -38,31 +39,28 @@ def get_endpoint_config(options: FinalRequestOptions) -> EndpointConfig:
return EndpointConfig(
message_template='Chat Completion with {request_data[model]!r}',
span_data={'request_data': json_data},
content_from_stream=content_from_chat_completions,
stream_state_cls=OpenaiChatCompletionStreamState,
)
elif url == '/completions':
return EndpointConfig(
message_template='Completion with {request_data[model]!r}',
span_data={'request_data': json_data},
content_from_stream=content_from_completions,
stream_state_cls=OpenaiCompletionStreamState,
)
elif url == '/embeddings':
return EndpointConfig(
message_template='Embedding Creation with {request_data[model]!r}',
span_data={'request_data': json_data},
content_from_stream=None,
)
elif url == '/images/generations':
return EndpointConfig(
message_template='Image Generation with {request_data[model]!r}',
span_data={'request_data': json_data},
content_from_stream=None,
)
else:
return EndpointConfig(
message_template='OpenAI API call to {url!r}',
span_data={'request_data': json_data, 'url': url},
content_from_stream=None,
)


Expand All @@ -72,10 +70,42 @@ def content_from_completions(chunk: Completion | None) -> str | None:
return None # pragma: no cover


def content_from_chat_completions(chunk: ChatCompletionChunk | None) -> str | None:
if chunk and chunk.choices:
return chunk.choices[0].delta.content
return None
class OpenaiCompletionStreamState(StreamState):
    """Accumulates the text content of a streamed OpenAI (non-chat) completion."""

    def __init__(self) -> None:
        # Non-empty text fragments, kept in the order they were recorded.
        self._content: list[str] = []

    def record_chunk(self, chunk: Completion) -> None:
        """Store the text of `chunk` if `content_from_completions` extracts any."""
        content = content_from_completions(chunk)
        if content:
            self._content.append(content)

    def get_response_data(self) -> Any:
        """Return the joined text plus the number of chunks that contributed text.

        `chunk_count` only counts chunks that yielded non-empty content, not
        every chunk passed to `record_chunk`.
        """
        return {'combined_chunk_content': ''.join(self._content), 'chunk_count': len(self._content)}


class OpenaiChatCompletionStreamState(StreamState):
def __init__(self):
self._stream_state = ChatCompletionStreamState(
# We do not need the response to be parsed into Python objects, so we can skip
# providing the `response_format` and `input_tools` arguments.
input_tools=openai.NOT_GIVEN,
response_format=openai.NOT_GIVEN,
)

def record_chunk(self, chunk: ChatCompletionChunk) -> None:
self._stream_state.handle_chunk(chunk)

def get_response_data(self) -> Any:
try:
final_completion = self._stream_state.current_completion_snapshot
except AssertionError:
# AssertionError is raised when there is no completion snapshot
# Return empty content to show an empty Assistant response in the UI
return {'combined_chunk_content': ''}
return {
'message': final_completion.choices[0].message if final_completion.choices else None,
'usage': final_completion.usage,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I worry about this having a significantly different shape from the other response data dicts.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the same shape as the existing non-streamed chat completion, which is why it displays nicely in the UI.

def on_response(response: ResponseT, span: LogfireSpan) -> ResponseT:
"""Updates the span based on the type of response."""
if isinstance(response, LegacyAPIResponse): # pragma: no cover
on_response(response.parse(), span) # type: ignore
return cast('ResponseT', response)
if isinstance(response, ChatCompletion):
span.set_attribute(
'response_data',
{'message': response.choices[0].message, 'usage': response.usage},
)

I'll admit it's not exactly the same: the message object here is a ParsedChatCompletionMessage, a subclass of the non-streamed chat completion message ChatCompletionMessage. It has the "parsed" and tool call "parsed_arguments" fields added. These "ParsedX" classes are those returned by the client.beta.chat.completions.parse method.

Aside: it would be handy to have the response dicts that have special handling on the frontend be documented or codified (e.g. using TypedDicts), ideally with the order of precedence. For example, I found that "combined_chunk_content" takes precedence over "message", which is why I excluded it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I didn't realise we were already using this shape. Yes, creating some types sounds good. @dmontagu any thoughts? Should we consider these shapes stable?

}


def on_response(response: ResponseT, span: LogfireSpan) -> ResponseT:
Expand Down
17 changes: 15 additions & 2 deletions logfire/_internal/integrations/llm_providers/types.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,26 @@
from __future__ import annotations

from typing import Any, Callable, NamedTuple
from abc import ABC, abstractmethod
from typing import Any, NamedTuple

from typing_extensions import LiteralString


class StreamState(ABC):
    """Keeps track of the state of a streamed response.

    Subclasses accumulate whatever they need from each chunk given to
    `record_chunk` and expose the result through `get_response_data`.
    Both methods are abstract, so this class cannot be instantiated directly.
    """

    @abstractmethod
    def record_chunk(self, chunk: Any) -> None:
        """Update the state based on a chunk from the streamed response."""

    @abstractmethod
    def get_response_data(self) -> Any:
        """Return the response data for including in the log."""


class EndpointConfig(NamedTuple):
"""The configuration for the endpoint of a provider based on request url."""

message_template: LiteralString
span_data: dict[str, Any]
content_from_stream: Callable[[Any], str | None] | None
stream_state_cls: type[StreamState] | None = None
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ exclude = [
"out",
"logfire-api",
]
venvPath = ".venv"
venvPath = "."
venv = ".venv"

[tool.pytest.ini_options]
xfail_strict = true
Expand Down
Loading