Attach retrieved documents to KnowledgeBase tool responses.

What this patch does:
  * skills/retriever_tool.py (new): RetrieverTool, a BaseTool whose _run returns
    the joined document text together with the Document list
    (response_format "content_and_artifact"), plus a create_retriever_tool factory.
  * members.py: GraphUpload.tool uses that factory instead of
    langchain.tools.retriever.create_retriever_tool.
  * messages.py: on_tool_end serialises the KnowledgeBase artifact into
    ChatResponse.documents; the on_retriever_end branch is removed.
  * checkpoint/utils.py: the same serialisation when replaying ToolMessages.

Review notes:
  * The artifact is read as `artifact or []` and the score via
    `metadata.get("score")`, so a ToolMessage without an artifact or a document
    without a score no longer raises.
  * The blob ids on the "index" lines predate these edits.
  * Leading whitespace of context lines was reconstructed; if a hunk is rejected,
    retry with whitespace-insensitive matching.
  * NOTE(review): GraphUpload.tool no longer forwards self.name/self.description,
    so every upload tool is named "KnowledgeBase" - confirm that is intended.

diff --git a/backend/app/core/graph/checkpoint/utils.py b/backend/app/core/graph/checkpoint/utils.py
index beefbf1a..1629c1c4 100644
--- a/backend/app/core/graph/checkpoint/utils.py
+++ b/backend/app/core/graph/checkpoint/utils.py
@@ -1,6 +1,8 @@
 import json
+from typing import Any
 from uuid import uuid4
 
+from langchain_core.documents import Document
 from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
 from langgraph.checkpoint.base import CheckpointTuple
 from psycopg import AsyncConnection
@@ -59,12 +61,24 @@ def convert_checkpoint_tuple_to_messages(
                 )
             )
         elif isinstance(message, ToolMessage) and message.name:
+            documents: list[dict[str, Any]] = []
+            if message.name == "KnowledgeBase":
+                # `artifact` is optional on ToolMessage, so tolerate None.
+                docs: list[Document] = message.artifact or []
+                for doc in docs:
+                    documents.append(
+                        {
+                            "score": doc.metadata.get("score"),
+                            "content": doc.page_content,
+                        }
+                    )
             formatted_messages.append(
                 ChatResponse(
                     type="tool",
                     id=message.tool_call_id,
                     name=message.name,
                     tool_output=json.dumps(message.content),
+                    documents=json.dumps(documents),
                 )
             )
         else:
diff --git a/backend/app/core/graph/members.py b/backend/app/core/graph/members.py
index 0334973e..1d1e5f52 100644
--- a/backend/app/core/graph/members.py
+++ b/backend/app/core/graph/members.py
@@ -2,7 +2,6 @@
 from typing import Annotated, Any
 
 from langchain.chat_models import init_chat_model
-from langchain.tools.retriever import create_retriever_tool
 from langchain_core.messages import AIMessage, AnyMessage
 from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
@@ -21,6 +20,7 @@
 from app.core.graph.rag.qdrant import QdrantStore
 from app.core.graph.skills import managed_skills
 from app.core.graph.skills.api_tool import dynamic_api_tool
+from app.core.graph.skills.retriever_tool import create_retriever_tool
 
 
 class GraphSkill(BaseModel):
@@ -49,9 +49,7 @@ class GraphUpload(BaseModel):
     @property
     def tool(self) -> BaseTool:
         retriever = QdrantStore().retriever(self.owner_id, self.upload_id)
-        return create_retriever_tool(
-            retriever, name=self.name, description=self.description
-        )
+        return create_retriever_tool(retriever)
 
 
 class GraphPerson(BaseModel):
diff --git a/backend/app/core/graph/messages.py b/backend/app/core/graph/messages.py
index c9008d4f..7ae3dd8a 100644
--- a/backend/app/core/graph/messages.py
+++ b/backend/app/core/graph/messages.py
@@ -75,27 +75,26 @@ def event_to_response(event: StreamEvent) -> ChatResponse | None:
     elif kind == "on_tool_end":
         tool_output: ToolMessage | None = event["data"].get("output")
         tool_name = event["name"]
+        # For the KnowledgeBase tool, serialise the documents carried in the
+        # artifact; `artifact` is optional on ToolMessage, so tolerate None.
+        documents: list[dict[str, Any]] = []
+        if tool_output and tool_output.name == "KnowledgeBase":
+            docs: list[Document] = tool_output.artifact or []
+            for doc in docs:
+                documents.append(
+                    {
+                        "score": doc.metadata.get("score"),
+                        "content": doc.page_content,
+                    }
+                )
         if tool_output:
             return ChatResponse(
                 type="tool",
                 id=id,
                 name=tool_name,
                 tool_output=json.dumps(tool_output.content),
+                documents=json.dumps(documents),
             )
-    elif kind == "on_retriever_end":
-        name = "documents"
-        docs: list[Document] = event["data"]["output"]
-        documents: list[dict[str, Any]] = []
-        for doc in docs:
-            documents.append(
-                {
-                    "score": doc.metadata["score"],
-                    "content": doc.page_content,
-                }
-            )
-        return ChatResponse(
-            type="retriever", id=id, name=name, documents=json.dumps(documents)
-        )
     # elif kind == "on_parser_end":
     #     content: str = event["data"]["output"].get("task")
     #     next = event["data"]["output"].get("next")
diff --git a/backend/app/core/graph/skills/retriever_tool.py b/backend/app/core/graph/skills/retriever_tool.py
new file mode 100644
index 00000000..3dccd0be
--- /dev/null
+++ b/backend/app/core/graph/skills/retriever_tool.py
@@ -0,0 +1,44 @@
+from typing import Annotated, Literal
+
+from langchain_core.documents import Document
+from langchain_core.prompts import BasePromptTemplate, PromptTemplate, format_document
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.tools import BaseTool
+
+
+class RetrieverTool(BaseTool):
+    """Tool that queries a retriever and returns text plus the raw documents."""
+
+    name: str = "KnowledgeBase"
+    description: str = "Query documents for answers."
+    # "content_and_artifact": `_run` returns a (content, artifact) pair.
+    response_format: Literal["content", "content_and_artifact"] = "content_and_artifact"
+
+    retriever: BaseRetriever
+    document_prompt: BasePromptTemplate | PromptTemplate  # type: ignore [type-arg]
+    document_separator: str
+
+    def _run(
+        self, query: Annotated[str, "query to look up in retriever"]
+    ) -> tuple[str, list[Document]]:
+        """Retrieve documents from knowledge base."""
+        docs = self.retriever.invoke(query, config={"callbacks": self.callbacks})
+        result_string = self.document_separator.join(
+            [format_document(doc, self.document_prompt) for doc in docs]
+        )
+        return result_string, docs
+
+
+def create_retriever_tool(
+    retriever: BaseRetriever,
+    document_prompt: BasePromptTemplate | None = None,  # type: ignore [type-arg]
+    document_separator: str = "\n\n",
+) -> BaseTool:
+    """Build a RetrieverTool; documents render as `{page_content}` by default."""
+    document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
+
+    return RetrieverTool(
+        retriever=retriever,
+        document_prompt=document_prompt,
+        document_separator=document_separator,
+    )