Skip to content

Commit

Permalink
Save fetched documents in history (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
StreetLamb authored Aug 4, 2024
1 parent 841535a commit e84bc5a
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 18 deletions.
13 changes: 13 additions & 0 deletions backend/app/core/graph/checkpoint/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
from typing import Any
from uuid import uuid4

from langchain_core.documents import Document
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
from langgraph.checkpoint.base import CheckpointTuple
from psycopg import AsyncConnection
Expand Down Expand Up @@ -59,12 +61,23 @@ def convert_checkpoint_tuple_to_messages(
)
)
elif isinstance(message, ToolMessage) and message.name:
documents: list[dict[str, Any]] = []
if message.name == "KnowledgeBase":
docs: list[Document] = message.artifact
for doc in docs:
documents.append(
{
"score": doc.metadata["score"],
"content": doc.page_content,
}
)
formatted_messages.append(
ChatResponse(
type="tool",
id=message.tool_call_id,
name=message.name,
tool_output=json.dumps(message.content),
documents=json.dumps(documents),
)
)
else:
Expand Down
6 changes: 2 additions & 4 deletions backend/app/core/graph/members.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Annotated, Any

from langchain.chat_models import init_chat_model
from langchain.tools.retriever import create_retriever_tool
from langchain_core.messages import AIMessage, AnyMessage
from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
Expand All @@ -21,6 +20,7 @@
from app.core.graph.rag.qdrant import QdrantStore
from app.core.graph.skills import managed_skills
from app.core.graph.skills.api_tool import dynamic_api_tool
from app.core.graph.skills.retriever_tool import create_retriever_tool


class GraphSkill(BaseModel):
Expand Down Expand Up @@ -49,9 +49,7 @@ class GraphUpload(BaseModel):
@property
def tool(self) -> BaseTool:
retriever = QdrantStore().retriever(self.owner_id, self.upload_id)
return create_retriever_tool(
retriever, name=self.name, description=self.description
)
return create_retriever_tool(retriever)


class GraphPerson(BaseModel):
Expand Down
26 changes: 12 additions & 14 deletions backend/app/core/graph/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,27 +75,25 @@ def event_to_response(event: StreamEvent) -> ChatResponse | None:
elif kind == "on_tool_end":
tool_output: ToolMessage | None = event["data"].get("output")
tool_name = event["name"]
# If the tool is KnowledgeBase, serialise the documents in its artifact
documents: list[dict[str, Any]] = []
if tool_output and tool_output.name == "KnowledgeBase":
docs: list[Document] = tool_output.artifact
for doc in docs:
documents.append(
{
"score": doc.metadata["score"],
"content": doc.page_content,
}
)
if tool_output:
return ChatResponse(
type="tool",
id=id,
name=tool_name,
tool_output=json.dumps(tool_output.content),
documents=json.dumps(documents),
)
elif kind == "on_retriever_end":
name = "documents"
docs: list[Document] = event["data"]["output"]
documents: list[dict[str, Any]] = []
for doc in docs:
documents.append(
{
"score": doc.metadata["score"],
"content": doc.page_content,
}
)
return ChatResponse(
type="retriever", id=id, name=name, documents=json.dumps(documents)
)
# elif kind == "on_parser_end":
# content: str = event["data"]["output"].get("task")
# next = event["data"]["output"].get("next")
Expand Down
40 changes: 40 additions & 0 deletions backend/app/core/graph/skills/retriever_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Annotated, Literal

from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, PromptTemplate, format_document
from langchain_core.retrievers import BaseRetriever
from langchain_core.tools import BaseTool


class RetrieverTool(BaseTool):
    """Tool that looks a query up in a retriever.

    ``_run`` hands back the retrieved documents in two forms: rendered and
    joined into one string, and as the raw ``Document`` list.
    """

    # Tool identity.
    name: str = "KnowledgeBase"
    description: str = "Query documents for answers."
    # Defaults to the mode matching the (text, documents) tuple from ``_run``.
    response_format: Literal["content", "content_and_artifact"] = "content_and_artifact"

    # Source of documents for a query.
    retriever: BaseRetriever
    # Prompt used to render each retrieved document to text.
    document_prompt: BasePromptTemplate | PromptTemplate  # type: ignore [type-arg]
    # String placed between the rendered documents.
    document_separator: str

    def _run(
        self, query: Annotated[str, "query to look up in retriever"]
    ) -> tuple[str, list[Document]]:
        """Retrieve documents from knowledge base."""
        # Pass this tool's callbacks along with the retriever call.
        retrieved = self.retriever.invoke(query, config={"callbacks": self.callbacks})
        rendered = (format_document(item, self.document_prompt) for item in retrieved)
        return self.document_separator.join(rendered), retrieved


def create_retriever_tool(
    retriever: BaseRetriever,
    document_prompt: BasePromptTemplate | None = None,  # type: ignore [type-arg]
    document_separator: str = "\n\n",
) -> BaseTool:
    """Build a ``RetrieverTool`` around ``retriever``.

    Args:
        retriever: Retriever the tool queries.
        document_prompt: Prompt used to render each document; when falsy
            (e.g. ``None``) a ``"{page_content}"`` template is used.
        document_separator: String placed between rendered documents.

    Returns:
        The configured ``RetrieverTool``.
    """
    prompt = (
        document_prompt
        if document_prompt
        else PromptTemplate.from_template("{page_content}")
    )
    return RetrieverTool(
        retriever=retriever,
        document_separator=document_separator,
        document_prompt=prompt,
    )

0 comments on commit e84bc5a

Please sign in to comment.