diff --git a/deployment/kustomizations/base/cm.yaml b/deployment/kustomizations/base/cm.yaml
index 02c6ae447..be4333aed 100644
--- a/deployment/kustomizations/base/cm.yaml
+++ b/deployment/kustomizations/base/cm.yaml
@@ -171,8 +171,12 @@ data:
         llm: litellm_llm.gpt-4o-mini-2024-07-18
       - name: chart_adjustment
         llm: litellm_llm.gpt-4o-mini-2024-07-18
+      - name: user_guide_assistance
+        llm: litellm_llm.gpt-4o-mini-2024-07-18
     ---
     settings:
+      doc_endpoint: https://docs.getwren.ai
+      is_oss: true
       column_indexing_batch_size: 50
       table_retrieval_size: 10
       table_column_retrieval_size: 100
diff --git a/docker/config.example.yaml b/docker/config.example.yaml
index 25cadacf1..1b0ec0ace 100644
--- a/docker/config.example.yaml
+++ b/docker/config.example.yaml
@@ -123,8 +123,13 @@ pipes:
     llm: litellm_llm.gpt-4o-mini-2024-07-18
   - name: chart_adjustment
     llm: litellm_llm.gpt-4o-mini-2024-07-18
+  - name: user_guide_assistance
+    llm: litellm_llm.gpt-4o-mini-2024-07-18
+
 ---
 settings:
+  doc_endpoint: https://docs.getwren.ai
+  is_oss: true
   column_indexing_batch_size: 50
   table_retrieval_size: 10
   table_column_retrieval_size: 100
diff --git a/wren-ai-service/demo/utils.py b/wren-ai-service/demo/utils.py
index 2f01958c7..58f2df875 100644
--- a/wren-ai-service/demo/utils.py
+++ b/wren-ai-service/demo/utils.py
@@ -619,8 +619,8 @@ def ask(query: str, timezone: str, query_history: Optional[dict] = None):
         if asks_status == "finished":
             st.session_state["asks_results_type"] = asks_type
 
-            if asks_type == "GENERAL":
-                display_general_response(query_id)
+            if asks_type == "GENERAL" or asks_type == "USER_GUIDE":
+                display_streaming_response(query_id)
             elif asks_type == "TEXT_TO_SQL":
                 st.session_state["asks_results"] = asks_status_response.json()["response"]
             else:
@@ -632,7 +632,7 @@
     )
 
 
-def display_general_response(query_id: str):
+def display_streaming_response(query_id: str):
     url = f"{WREN_AI_SERVICE_BASE_URL}/v1/asks/{query_id}/streaming-result"
     headers = {"Accept": "text/event-stream"}
     response = with_requests(url, headers)
diff --git a/wren-ai-service/src/config.py b/wren-ai-service/src/config.py
index 4ed5ca697..173dd2b74 100644
--- a/wren-ai-service/src/config.py
+++ b/wren-ai-service/src/config.py
@@ -41,6 +41,10 @@ class Settings(BaseSettings):
         """,
     )
 
+    # user guide config
+    is_oss: bool = Field(default=True)
+    doc_endpoint: str = Field(default="https://docs.getwren.ai")
+
     # langfuse config
     # in order to use langfuse, we also need to set the LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY in the .env or .env.dev file
     langfuse_host: str = Field(default="https://cloud.langfuse.com")
diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py
index 737bfda1c..ad0d2f05e 100644
--- a/wren-ai-service/src/globals.py
+++ b/wren-ai-service/src/globals.py
@@ -87,6 +87,11 @@ def create_service_container(
             "data_assistance": generation.DataAssistance(
                 **pipe_components["data_assistance"]
             ),
+            "user_guide_assistance": generation.UserGuideAssistance(
+                **pipe_components["user_guide_assistance"],
+                is_oss=settings.is_oss,
+                doc_endpoint=settings.doc_endpoint,
+            ),
             "retrieval": retrieval.Retrieval(
                 **pipe_components["db_schema_retrieval"],
                 table_retrieval_size=settings.table_retrieval_size,
diff --git a/wren-ai-service/src/pipelines/generation/__init__.py b/wren-ai-service/src/pipelines/generation/__init__.py
index 559cec5ac..05fe97461 100644
--- a/wren-ai-service/src/pipelines/generation/__init__.py
+++ b/wren-ai-service/src/pipelines/generation/__init__.py
@@ -14,6 +14,7 @@
 from .sql_generation import SQLGeneration
 from .sql_regeneration import SQLRegeneration
 from .sql_summary import SQLSummary
+from .user_guide_assistance import UserGuideAssistance
 
 __all__ = [
     "SQLRegeneration",
@@ -32,4 +33,5 @@
     "SQLExplanation",
     "SQLGeneration",
     "SQLSummary",
+    "UserGuideAssistance",
 ]
diff --git a/wren-ai-service/src/pipelines/generation/intent_classification.py b/wren-ai-service/src/pipelines/generation/intent_classification.py
index f3ce40edf..43c3316ad 100644
--- a/wren-ai-service/src/pipelines/generation/intent_classification.py
+++ b/wren-ai-service/src/pipelines/generation/intent_classification.py
@@ -24,7 +24,7 @@
 ### TASK ###
 You are a great detective, who is great at intent classification.
 First, rephrase the user's question to make it more specific, clear and relevant to the database schema before making the intent classification.
-Second, you need to use rephrased user's question to classify user's intent based on given database schema to one of three conditions: MISLEADING_QUERY, TEXT_TO_SQL, GENERAL.
+Second, you need to use the rephrased user's question to classify the user's intent, based on the given database schema, into one of four categories: MISLEADING_QUERY, TEXT_TO_SQL, GENERAL, USER_GUIDE.
 Also you should provide reasoning for the classification clearly and concisely within 20 words.
 
 ### INSTRUCTIONS ###
@@ -79,16 +79,25 @@
   - Examples:
     - "What is the dataset about?"
     - "Tell me more about the database."
-    - "What can Wren AI do?"
     - "How can I analyze customer behavior with this data?"
-
+- USER_GUIDE
+  - When to Use:
+    - If the user's question is about Wren AI's features, capabilities, or how to use Wren AI.
+  - Characteristics:
+    - The question is about the product itself (its features, settings, or usage) rather than the data it connects to.
+  - Examples:
+    - "What can Wren AI do?"
+    - "How can I reset a project?"
+    - "How can I delete a project?"
+    - "How can I connect to other databases?"
+
 ### OUTPUT FORMAT ###
 Please provide your response as a JSON object, structured as follows:
 
 {
     "rephrased_question": "",
     "reasoning": "",
-    "results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "GENERAL"
+    "results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "GENERAL" | "USER_GUIDE"
 }
 """
 
@@ -264,7 +273,7 @@ def post_process(classify_intent: dict, construct_db_schemas: list[str]) -> dict
 
 
 class IntentClassificationResult(BaseModel):
-    results: Literal["MISLEADING_QUERY", "TEXT_TO_SQL", "GENERAL"]
+    results: Literal["MISLEADING_QUERY", "TEXT_TO_SQL", "GENERAL", "USER_GUIDE"]
     rephrased_question: str
     reasoning: str
 
diff --git a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py
new file mode 100644
index 000000000..df482fd30
--- /dev/null
+++ b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py
@@ -0,0 +1,195 @@
+import asyncio
+import logging
+import sys
+from typing import Any, Optional
+
+import aiohttp
+from hamilton import base
+from hamilton.async_driver import AsyncDriver
+from haystack.components.builders.prompt_builder import PromptBuilder
+from langfuse.decorators import observe
+
+from src.core.pipeline import BasicPipeline
+from src.core.provider import LLMProvider
+from src.utils import remove_trailing_slash
+
+logger = logging.getLogger("wren-ai-service")
+
+
+user_guide_assistance_system_prompt = """
+You are a helpful assistant that can help users understand Wren AI.
+You are given a user question and a user guide.
+You need to understand the user question and the user guide, and then answer the user question.
+
+### INSTRUCTIONS ###
+1. Your answer should be in the same language as the language the user provided.
+2. You must follow the user guide to answer the user question.
+3. If you think you cannot answer the user question given the user guide, you should simply say "I don't know".
+4. You should add citations to the user guide (document URL) in your answer.
+5. You should provide your answer in Markdown format.
+"""

+user_guide_assistance_user_prompt_template = """
+User Question: {{query}}
+Language: {{language}}
+User Guide:
+{% for doc in docs %}
+- {{doc.path}}: {{doc.content}}
+{% endfor %}
+Doc Endpoint: {{doc_endpoint}}
+
+Please think step by step.
+"""
+
+
+## Start of Pipeline
+@observe
+async def fetch_wren_ai_docs(doc_endpoint: str, is_oss: bool) -> str:
+    doc_endpoint = remove_trailing_slash(doc_endpoint)
+    api_endpoint = (
+        f"{doc_endpoint}/oss/llms.md" if is_oss else f"{doc_endpoint}/cloud/llms.md"
+    )
+
+    async with aiohttp.request(
+        "GET",
+        api_endpoint,
+    ) as response:
+        data = await response.text()
+
+    return data
+
+
+@observe(capture_input=False)
+def prompt(
+    query: str,
+    language: str,
+    fetch_wren_ai_docs: str,
+    doc_endpoint: str,
+    is_oss: bool,
+    prompt_builder: PromptBuilder,
+) -> dict:
+    doc_endpoint_base = f"{doc_endpoint}/oss" if is_oss else f"{doc_endpoint}/cloud"
+
+    documents = fetch_wren_ai_docs.split("\n---\n")
+    docs = []
+    for doc in documents:
+        if doc:
+            path, content = doc.split("\n", 1)  # keep multi-line doc content intact
+            docs.append(
+                {
+                    "path": f'{doc_endpoint_base}/{path.removesuffix(".md")}',
+                    "content": content,
+                }
+            )
+
+    return prompt_builder.run(
+        query=query,
+        language=language,
+        doc_endpoint=doc_endpoint,
+        docs=docs,
+    )
+
+
+@observe(as_type="generation", capture_input=False)
+async def user_guide_assistance(prompt: dict, generator: Any, query_id: str) -> dict:
+    return await generator(prompt=prompt.get("prompt"), query_id=query_id)
+
+
+## End of Pipeline
+
+
+USER_GUIDE_ASSISTANCE_MODEL_KWARGS = {"response_format": {"type": "text"}}
+
+
+class UserGuideAssistance(BasicPipeline):
+    def __init__(
+        self,
+        llm_provider: LLMProvider,
+        is_oss: bool,
+        doc_endpoint: str,
+        **kwargs,
+    ):
+        self._user_queues = {}
+        self._components = {
+            "generator": llm_provider.get_generator(
+                system_prompt=user_guide_assistance_system_prompt,
+                generation_kwargs=USER_GUIDE_ASSISTANCE_MODEL_KWARGS,
+                streaming_callback=self._streaming_callback,
+            ),
+            "prompt_builder": PromptBuilder(
+                template=user_guide_assistance_user_prompt_template
+            ),
+        }
+        self._configs = {
+            "is_oss": is_oss,
+            "doc_endpoint": doc_endpoint,
+        }
+
+        super().__init__(
+            AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
+        )
+
+    def _streaming_callback(self, chunk, query_id):
+        if query_id not in self._user_queues:
+            self._user_queues[
+                query_id
+            ] = asyncio.Queue()  # Create a new queue for the user if it doesn't exist
+        # Put the chunk content into the user's queue
+        asyncio.create_task(self._user_queues[query_id].put(chunk.content))
+        if chunk.meta.get("finish_reason"):
+            asyncio.create_task(self._user_queues[query_id].put(""))
+
+    async def get_streaming_results(self, query_id):
+        async def _get_streaming_results(query_id):
+            return await self._user_queues[query_id].get()
+
+        if query_id not in self._user_queues:
+            self._user_queues[
+                query_id
+            ] = asyncio.Queue()  # Ensure the user's queue exists
+        while True:
+            try:
+                # Wait for an item from the user's queue
+                self._streaming_results = await asyncio.wait_for(
+                    _get_streaming_results(query_id), timeout=120
+                )
+                if (
+                    self._streaming_results == ""
+                ):  # Check for end-of-stream signal
+                    del self._user_queues[query_id]
+                    break
+                if self._streaming_results:  # Check if there are results to yield
+                    yield self._streaming_results
+                    self._streaming_results = ""  # Clear after yielding
+            except asyncio.TimeoutError:
+                break
+
+    @observe(name="User Guide Assistance")
+    async def run(
+        self,
+        query: str,
+        language: str,
+        query_id: Optional[str] = None,
+    ) -> dict:
+        logger.info("User Guide Assistance pipeline is running...")
+        return await self._pipe.execute(
+            ["user_guide_assistance"],
+            inputs={
+                "query": query,
+                "language": language,
+                "query_id": query_id or "",
+                **self._components,
+                **self._configs,
+            },
+        )
+
+
+if __name__ == "__main__":
+    from src.pipelines.common import dry_run_pipeline
+
+    dry_run_pipeline(
+        UserGuideAssistance,
+        "user_guide_assistance",
+        query="what can Wren AI do?",
+        language="en",
+    )
diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py
index 6bf26e7fa..b94803832 100644
--- a/wren-ai-service/src/web/v1/services/ask.py
+++ b/wren-ai-service/src/web/v1/services/ask.py
@@ -91,7 +91,9 @@ class AskResultResponse(BaseModel):
     ]
     rephrased_question: Optional[str] = None
     intent_reasoning: Optional[str] = None
-    type: Optional[Literal["MISLEADING_QUERY", "GENERAL", "TEXT_TO_SQL"]] = None
+    type: Optional[
+        Literal["MISLEADING_QUERY", "GENERAL", "TEXT_TO_SQL", "USER_GUIDE"]
+    ] = None
     response: Optional[List[AskResult]] = None
     error: Optional[AskError] = None
 
@@ -107,6 +109,10 @@ def __init__(
         self._ask_results: Dict[str, AskResultResponse] = TTLCache(
             maxsize=maxsize, ttl=ttl
         )
+        self._ask_result_type_to_pipeline = {
+            "GENERAL": self._pipelines["data_assistance"],
+            "USER_GUIDE": self._pipelines["user_guide_assistance"],
+        }
 
     def _is_stopped(self, query_id: str):
         if (
@@ -222,6 +228,21 @@ async def ask(
             )
             results["metadata"]["type"] = "GENERAL"
             return results
+        elif intent == "USER_GUIDE":
+            asyncio.create_task(
+                self._pipelines["user_guide_assistance"].run(
+                    query=ask_request.query,
+                    language=ask_request.configurations.language,
+                    query_id=ask_request.query_id,
+                )
+            )
+
+            self._ask_results[query_id] = AskResultResponse(
+                status="finished",
+                type="USER_GUIDE",
+            )
+            results["metadata"]["type"] = "USER_GUIDE"
+            return results
         else:
             self._ask_results[query_id] = AskResultResponse(
                 status="understanding",
@@ -418,13 +439,10 @@ async def get_ask_streaming_result(
         self,
         query_id: str,
     ):
-        if (
-            self._ask_results.get(query_id)
-            and self._ask_results.get(query_id).type == "GENERAL"
+        if (ask_result := self._ask_results.get(query_id)) and (
+            pipeline := self._ask_result_type_to_pipeline.get(ask_result.type)
         ):
-            async for chunk in self._pipelines["data_assistance"].get_streaming_results(
-                query_id
-            ):
+            async for chunk in pipeline.get_streaming_results(query_id):
                 event = SSEEvent(
                     data=SSEEvent.SSEEventMessage(message=chunk),
                 )
diff --git a/wren-ai-service/tests/data/config.test.yaml b/wren-ai-service/tests/data/config.test.yaml
index ae366d17f..1e1517bff 100644
--- a/wren-ai-service/tests/data/config.test.yaml
+++ b/wren-ai-service/tests/data/config.test.yaml
@@ -75,12 +75,18 @@ pipes:
   - name: relationship_recommendation
     llm: openai_llm.gpt-4o-mini
     engine: wren_ui
-
+  - name: user_guide_assistance
+    llm: openai_llm.gpt-4o-mini
+  - name: data_assistance
+    llm: openai_llm.gpt-4o-mini
+
 ---
 settings:
   host: 127.0.0.1
   port: 5556
   column_indexing_batch_size: 50
+  doc_endpoint: https://docs.getwren.ai
+  is_oss: true
   table_retrieval_size: 10
   table_column_retrieval_size: 1000
   query_cache_maxsize: 1000
diff --git a/wren-ai-service/tests/pytest/services/test_ask.py b/wren-ai-service/tests/pytest/services/test_ask.py
index 1b5229297..6511e84f1 100644
--- a/wren-ai-service/tests/pytest/services/test_ask.py
+++ b/wren-ai-service/tests/pytest/services/test_ask.py
@@ -40,6 +40,11 @@ def ask_service():
             "data_assistance": generation.DataAssistance(
                 **pipe_components["data_assistance"],
             ),
+            "user_guide_assistance": generation.UserGuideAssistance(
+                **pipe_components["user_guide_assistance"],
+                is_oss=settings.is_oss,
+                doc_endpoint=settings.doc_endpoint,
+            ),
             "retrieval": retrieval.Retrieval(
                 **pipe_components["db_schema_retrieval"],
             ),
@@ -153,13 +158,14 @@ def _ask_service_ttl_mock(query: str):
     return AskService(
         {
             "intent_classification": IntentClassificationMock(),
-            "data_assistance": "",
             "retrieval": RetrievalMock(
                 [
                     f"mock document 1 for {query}",
                     f"mock document 2 for {query}",
                 ]
             ),
+            "data_assistance": "",
+            "user_guide_assistance": "",
             "historical_question": HistoricalQuestionMock(),
             "sql_generation": GenerationMock(
                 valid=[{"sql": "select count(*) from books"}],
diff --git a/wren-ai-service/tools/config/config.example.yaml b/wren-ai-service/tools/config/config.example.yaml
index 0d1d8344e..f85e36b38 100644
--- a/wren-ai-service/tools/config/config.example.yaml
+++ b/wren-ai-service/tools/config/config.example.yaml
@@ -137,10 +137,14 @@ pipes:
     llm: litellm_llm.gpt-4o-mini-2024-07-18
   - name: sql_executor
     engine: wren_ui
+  - name: user_guide_assistance
+    llm: litellm_llm.gpt-4o-mini-2024-07-18
 ---
 settings:
   host: 127.0.0.1
   port: 5556
+  doc_endpoint: https://docs.getwren.ai
+  is_oss: true
   column_indexing_batch_size: 50
   table_retrieval_size: 10
   table_column_retrieval_size: 100
diff --git a/wren-ai-service/tools/config/config.full.yaml b/wren-ai-service/tools/config/config.full.yaml
index 644bf8dc3..41a9efdaf 100644
--- a/wren-ai-service/tools/config/config.full.yaml
+++ b/wren-ai-service/tools/config/config.full.yaml
@@ -156,11 +156,15 @@ pipes:
     llm: litellm_llm.gpt-4o-mini-2024-07-18
   - name: sql_executor
     engine: wren_ui
+  - name: user_guide_assistance
+    llm: litellm_llm.gpt-4o-mini-2024-07-18
 
 ---
 settings:
   host: 127.0.0.1
   port: 5556
+  doc_endpoint: https://docs.getwren.ai
+  is_oss: true
   column_indexing_batch_size: 50
   table_retrieval_size: 10
   table_column_retrieval_size: 100
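
Note on the docs payload: the new prompt() step assumes fetch_wren_ai_docs returns a single llms.md file in which documents are separated by "\n---\n" and each document's first line is its relative path, with the markdown body on the following lines. A minimal standalone sketch of that parsing convention; the Doc shape, the parse_docs helper, and the sample payload are illustrative, not part of this patch:

from typing import TypedDict


class Doc(TypedDict):
    path: str
    content: str


def parse_docs(raw: str, doc_endpoint_base: str) -> list[Doc]:
    docs: list[Doc] = []
    for doc in raw.split("\n---\n"):  # documents are "---"-separated
        if not doc.strip():
            continue
        # maxsplit=1 keeps multi-line markdown bodies intact
        path, content = doc.split("\n", 1)
        docs.append(
            {
                "path": f"{doc_endpoint_base}/{path.removesuffix('.md')}",
                "content": content,
            }
        )
    return docs


# Hypothetical two-document payload:
raw = "faq.md\nQ: ...\nA: ...\n---\nreset-project.md\nUse Settings > Reset."
print(parse_docs(raw, "https://docs.getwren.ai/oss"))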
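
The streaming design in UserGuideAssistance mirrors the existing DataAssistance pipeline: the LLM provider's streaming callback produces chunks into a per-query asyncio.Queue, and get_streaming_results() consumes them until an empty-string sentinel (pushed on finish_reason) or a 120-second timeout. A self-contained sketch of that producer/consumer pattern, with illustrative class and method names:

import asyncio


class StreamBuffer:
    """Per-query chunk buffer: producer is a streaming callback, consumer an async generator."""

    def __init__(self) -> None:
        self._queues: dict[str, asyncio.Queue] = {}

    def _queue(self, query_id: str) -> asyncio.Queue:
        return self._queues.setdefault(query_id, asyncio.Queue())

    def on_chunk(self, query_id: str, content: str, finished: bool = False) -> None:
        # Producer side: called once per streamed chunk by the LLM callback.
        self._queue(query_id).put_nowait(content)
        if finished:
            self._queue(query_id).put_nowait("")  # empty string marks end of stream

    async def stream(self, query_id: str, timeout: float = 120.0):
        # Consumer side: yield chunks until the sentinel arrives or we time out.
        queue = self._queue(query_id)
        while True:
            try:
                chunk = await asyncio.wait_for(queue.get(), timeout=timeout)
            except asyncio.TimeoutError:
                break
            if chunk == "":
                del self._queues[query_id]  # stream finished; drop the queue
                break
            yield chunk


async def main() -> None:
    buffer = StreamBuffer()
    for piece in ("Wren AI ", "can answer ", "user-guide questions."):
        buffer.on_chunk("q1", piece)
    buffer.on_chunk("q1", "", finished=True)
    async for chunk in buffer.stream("q1"):
        print(chunk, end="")
    print()


asyncio.run(main())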
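
On the service side, get_ask_streaming_result() wraps each chunk in an SSEEvent whose data is SSEEventMessage(message=chunk), and the demo reads it via GET /v1/asks/{query_id}/streaming-result with Accept: text/event-stream. A rough client sketch, assuming the service listens on the port used in the example configs (5556) and that each SSE data: line serializes to a JSON object with a "message" field; both are assumptions, not confirmed by this patch:

import json

import requests

BASE_URL = "http://localhost:5556"  # assumed; matches the port in the example configs


def stream_answer(query_id: str):
    # Consume the ask result as a server-sent event stream, yielding message chunks.
    url = f"{BASE_URL}/v1/asks/{query_id}/streaming-result"
    headers = {"Accept": "text/event-stream"}
    with requests.get(url, headers=headers, stream=True, timeout=120) as response:
        response.raise_for_status()
        for line in response.iter_lines(decode_unicode=True):
            if line and line.startswith("data:"):
                payload = json.loads(line[len("data:"):].strip())
                yield payload.get("message", "")


if __name__ == "__main__":
    for chunk in stream_answer("some-query-id"):
        print(chunk, end="", flush=True)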