feat(wren-ai-service): user guide #1015

Draft · wants to merge 17 commits into main
4 changes: 4 additions & 0 deletions deployment/kustomizations/base/cm.yaml
@@ -171,8 +171,12 @@ data:
llm: litellm_llm.gpt-4o-mini-2024-07-18
- name: chart_adjustment
llm: litellm_llm.gpt-4o-mini-2024-07-18
- name: user_guide_assistance
llm: litellm_llm.gpt-4o-mini-2024-07-18
---
settings:
doc_endpoint: https://docs.getwren.ai
is_oss: true
column_indexing_batch_size: 50
table_retrieval_size: 10
table_column_retrieval_size: 100
5 changes: 5 additions & 0 deletions docker/config.example.yaml
@@ -123,8 +123,13 @@ pipes:
llm: litellm_llm.gpt-4o-mini-2024-07-18
- name: chart_adjustment
llm: litellm_llm.gpt-4o-mini-2024-07-18
- name: user_guide_assistance
llm: litellm_llm.gpt-4o-mini-2024-07-18

---
settings:
doc_endpoint: https://docs.getwren.ai
is_oss: true
column_indexing_batch_size: 50
table_retrieval_size: 10
table_column_retrieval_size: 100
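Both config files pick up the same two additions: a `user_guide_assistance` entry under `pipes` (selecting which LLM backs the new pipeline) and the `doc_endpoint`/`is_oss` pair under `settings`. As a quick sanity check that an edited file parses as expected, here is a sketch using PyYAML (the config is multi-document YAML, hence `safe_load_all`; the file path is the one from this diff):

```python
import yaml  # pip install pyyaml

with open("docker/config.example.yaml") as f:
    documents = list(yaml.safe_load_all(f))

# The settings block lives in its own YAML document after the `---` separator.
settings = next(
    d["settings"] for d in documents if isinstance(d, dict) and "settings" in d
)
assert settings["doc_endpoint"] == "https://docs.getwren.ai"
assert settings["is_oss"] is True

# The new pipe should name its LLM just like its siblings.
pipes = next(d["pipes"] for d in documents if isinstance(d, dict) and "pipes" in d)
assert any(p["name"] == "user_guide_assistance" for p in pipes)
```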
6 changes: 3 additions & 3 deletions wren-ai-service/demo/utils.py
@@ -619,8 +619,8 @@ def ask(query: str, timezone: str, query_history: Optional[dict] = None):

if asks_status == "finished":
st.session_state["asks_results_type"] = asks_type
if asks_type == "GENERAL":
display_general_response(query_id)
if asks_type == "GENERAL" or asks_type == "USER_GUIDE":
display_streaming_response(query_id)
elif asks_type == "TEXT_TO_SQL":
st.session_state["asks_results"] = asks_status_response.json()["response"]
else:
@@ -632,7 +632,7 @@ def ask(query: str, timezone: str, query_history: Optional[dict] = None):
)


def display_general_response(query_id: str):
def display_streaming_response(query_id: str):
url = f"{WREN_AI_SERVICE_BASE_URL}/v1/asks/{query_id}/streaming-result"
headers = {"Accept": "text/event-stream"}
response = with_requests(url, headers)
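With this rename, both GENERAL and USER_GUIDE answers are read from the same `/v1/asks/{query_id}/streaming-result` endpoint. For reference, a standalone sketch of such a consumer, assuming the endpoint emits standard `data: {"message": ...}` server-sent events as suggested by the `SSEEvent` usage in ask.py below (`with_requests` itself is outside this diff):

```python
import json

import requests

def stream_answer(base_url: str, query_id: str) -> str:
    """Collect a streamed answer chunk by chunk (hypothetical helper)."""
    url = f"{base_url}/v1/asks/{query_id}/streaming-result"
    chunks = []
    with requests.get(
        url, headers={"Accept": "text/event-stream"}, stream=True
    ) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            # SSE data lines look like: data: {"message": "<chunk>"}
            if line and line.startswith("data:"):
                payload = json.loads(line[len("data:"):].strip())
                chunks.append(payload.get("message", ""))
    return "".join(chunks)
```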
4 changes: 4 additions & 0 deletions wren-ai-service/src/config.py
@@ -41,6 +41,10 @@ class Settings(BaseSettings):
""",
)

# user guide config
is_oss: bool = Field(default=True)
doc_endpoint: str = Field(default="https://docs.getwren.ai")

# langfuse config
# in order to use langfuse, we also need to set the LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY in the .env or .env.dev file
langfuse_host: str = Field(default="https://cloud.langfuse.com")
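The two new fields follow the existing `Settings` conventions, so they should be overridable the same ways as their neighbors. A small sketch under that assumption (pydantic's `BaseSettings` reads same-named environment variables by default; exact casing depends on the model config, and this assumes `Settings()` can be constructed without other required arguments):

```python
import os

from src.config import Settings

# Defaults from the Field declarations in this diff:
settings = Settings()
assert settings.is_oss is True
assert settings.doc_endpoint == "https://docs.getwren.ai"

# Example override for a non-OSS deployment (hypothetical values):
os.environ["IS_OSS"] = "false"
os.environ["DOC_ENDPOINT"] = "https://docs.internal.example.com"
settings = Settings()  # re-instantiate so the env overrides are picked up
```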
5 changes: 5 additions & 0 deletions wren-ai-service/src/globals.py
@@ -87,6 +87,11 @@ def create_service_container(
"data_assistance": generation.DataAssistance(
**pipe_components["data_assistance"]
),
"user_guide_assistance": generation.UserGuideAssistance(
**pipe_components["user_guide_assistance"],
is_oss=settings.is_oss,
doc_endpoint=settings.doc_endpoint,
),
"retrieval": retrieval.Retrieval(
**pipe_components["db_schema_retrieval"],
table_retrieval_size=settings.table_retrieval_size,
2 changes: 2 additions & 0 deletions wren-ai-service/src/pipelines/generation/__init__.py
@@ -14,6 +14,7 @@
from .sql_generation import SQLGeneration
from .sql_regeneration import SQLRegeneration
from .sql_summary import SQLSummary
from .user_guide_assistance import UserGuideAssistance

__all__ = [
"SQLRegeneration",
@@ -32,4 +33,5 @@
"SQLExplanation",
"SQLGeneration",
"SQLSummary",
"UserGuideAssistance",
]
19 changes: 14 additions & 5 deletions wren-ai-service/src/pipelines/generation/intent_classification.py
@@ -24,7 +24,7 @@
### TASK ###
You are a great detective, who is great at intent classification.
First, rephrase the user's question to make it more specific, clear and relevant to the database schema before making the intent classification.
Second, you need to use rephrased user's question to classify user's intent based on given database schema to one of three conditions: MISLEADING_QUERY, TEXT_TO_SQL, GENERAL.
Second, use the rephrased question to classify the user's intent, based on the given database schema, into one of four categories: MISLEADING_QUERY, TEXT_TO_SQL, GENERAL, USER_GUIDE.
Also you should provide reasoning for the classification clearly and concisely within 20 words.

### INSTRUCTIONS ###
@@ -79,16 +79,25 @@
- Examples:
- "What is the dataset about?"
- "Tell me more about the database."
- "What can Wren AI do?"
- "How can I analyze customer behavior with this data?"

- USER_GUIDE
- When to Use:
- If the user's question is about Wren AI's features, capabilities, or how to use Wren AI.
- Characteristics:
- The question focuses on Wren AI itself (its features, setup, or usage) rather than on the connected data.
- Examples:
- "What can Wren AI do?"
- "How can I reset project?"
- "How can I delete project?"
- "How can I connect to other databases?"

### OUTPUT FORMAT ###
Please provide your response as a JSON object, structured as follows:

{
"rephrased_question": "<REPHRASED_USER_QUESTION_IN_STRING_FORMAT>",
"reasoning": "<CHAIN_OF_THOUGHT_REASONING_BASED_ON_REPHRASED_USER_QUESTION_IN_STRING_FORMAT>",
"results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "GENERAL"
"results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "GENERAL" | "USER_GUIDE"
}
"""

@@ -264,7 +273,7 @@ def post_process(classify_intent: dict, construct_db_schemas: list[str]) -> dict


class IntentClassificationResult(BaseModel):
results: Literal["MISLEADING_QUERY", "TEXT_TO_SQL", "GENERAL"]
results: Literal["MISLEADING_QUERY", "TEXT_TO_SQL", "GENERAL", "USER_GUIDE"]
rephrased_question: str
reasoning: str

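Since the LLM is instructed to return JSON matching this schema, the extended `Literal` gives a cheap guard before branching on the intent. A self-contained sketch (the model class is re-declared here so the snippet runs on its own):

```python
from typing import Literal

from pydantic import BaseModel, ValidationError

class IntentClassificationResult(BaseModel):
    results: Literal["MISLEADING_QUERY", "TEXT_TO_SQL", "GENERAL", "USER_GUIDE"]
    rephrased_question: str
    reasoning: str

raw = {
    "rephrased_question": "How do I connect Wren AI to another database?",
    "reasoning": "Question concerns product usage, not the data itself.",
    "results": "USER_GUIDE",
}

try:
    intent = IntentClassificationResult(**raw)
except ValidationError:
    intent = None  # e.g. treat malformed output as MISLEADING_QUERY

print(intent.results if intent else "unclassified")  # USER_GUIDE
```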
195 changes: 195 additions & 0 deletions wren-ai-service/src/pipelines/generation/user_guide_assistance.py
@@ -0,0 +1,195 @@
import asyncio
import logging
import sys
from typing import Any, Optional

import aiohttp
from hamilton import base
from hamilton.async_driver import AsyncDriver
from haystack.components.builders.prompt_builder import PromptBuilder
from langfuse.decorators import observe

from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider
from src.utils import remove_trailing_slash

logger = logging.getLogger("wren-ai-service")


user_guide_assistance_system_prompt = """
You are a helpful assistant that can help users understand Wren AI.
You are given a user question and a user guide.
You need to understand the user question and the user guide, and then answer the user question.

### INSTRUCTIONS ###
1. Your answer should be in the same language as the one the user provided.
2. You must follow the user guide to answer the user question.
3. If you think you cannot answer the user question given the user guide, you should simply say "I don't know".
4. You should add citations to the user guide (document URL) in your answer.
5. You should provide your answer in Markdown format.
"""

user_guide_assistance_user_prompt_template = """
User Question: {{query}}
Language: {{language}}
User Guide:
{% for doc in docs %}
- {{doc.path}}: {{doc.content}}
{% endfor %}
Doc Endpoint: {{doc_endpoint}}

Please think step by step.
"""


## Start of Pipeline
@observe
async def fetch_wren_ai_docs(doc_endpoint: str, is_oss: bool) -> str:
doc_endpoint = remove_trailing_slash(doc_endpoint)
api_endpoint = (
f"{doc_endpoint}/oss/llms.md" if is_oss else f"{doc_endpoint}/cloud/llms.md"
)

async with aiohttp.request(
"GET",
api_endpoint,
) as response:
data = await response.text()

return data


@observe(capture_input=False)
def prompt(
query: str,
language: str,
fetch_wren_ai_docs: str,
doc_endpoint: str,
is_oss: bool,
prompt_builder: PromptBuilder,
) -> dict:
doc_endpoint_base = f"{doc_endpoint}/oss" if is_oss else f"{doc_endpoint}/cloud"

documents = fetch_wren_ai_docs.split("\n---\n")
docs = []
for doc in documents:
if doc:
path, content = doc.split("\n")
docs.append(
{
"path": f'{doc_endpoint_base}/{path.replace(".md", "")}',
"content": content,
}
)

return prompt_builder.run(
query=query,
language=language,
doc_endpoint=doc_endpoint,
docs=docs,
)


@observe(as_type="generation", capture_input=False)
async def user_guide_assistance(prompt: dict, generator: Any, query_id: str) -> dict:
return await generator(prompt=prompt.get("prompt"), query_id=query_id)


## End of Pipeline


USER_GUIDE_ASSISTANCE_MODEL_KWARGS = {"response_format": {"type": "text"}}


class UserGuideAssistance(BasicPipeline):
def __init__(
self,
llm_provider: LLMProvider,
is_oss: bool,
doc_endpoint: str,
**kwargs,
):
self._user_queues = {}
self._components = {
"generator": llm_provider.get_generator(
system_prompt=user_guide_assistance_system_prompt,
generation_kwargs=USER_GUIDE_ASSISTANCE_MODEL_KWARGS,
streaming_callback=self._streaming_callback,
),
"prompt_builder": PromptBuilder(
template=user_guide_assistance_user_prompt_template
),
}
self._configs = {
"is_oss": is_oss,
"doc_endpoint": doc_endpoint,
}

super().__init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)

def _streaming_callback(self, chunk, query_id):
if query_id not in self._user_queues:
self._user_queues[
query_id
] = asyncio.Queue() # Create a new queue for the user if it doesn't exist
# Put the chunk content into the user's queue
asyncio.create_task(self._user_queues[query_id].put(chunk.content))
if chunk.meta.get("finish_reason"):
asyncio.create_task(self._user_queues[query_id].put("<DONE>"))

async def get_streaming_results(self, query_id):
async def _get_streaming_results(query_id):
return await self._user_queues[query_id].get()

if query_id not in self._user_queues:
self._user_queues[
query_id
] = asyncio.Queue() # Ensure the user's queue exists
while True:
try:
# Wait for an item from the user's queue
self._streaming_results = await asyncio.wait_for(
_get_streaming_results(query_id), timeout=120
)
if (
self._streaming_results == "<DONE>"
): # Check for end-of-stream signal
del self._user_queues[query_id]
break
if self._streaming_results: # Check if there are results to yield
yield self._streaming_results
self._streaming_results = "" # Clear after yielding
except TimeoutError:
break

@observe(name="User Guide Assistance")
async def run(
self,
query: str,
language: str,
query_id: Optional[str] = None,
) -> None:
logger.info("User Guide Assistance pipeline is running...")
return await self._pipe.execute(
["user_guide_assistance"],
inputs={
"query": query,
"language": language,
"query_id": query_id or "",
**self._components,
**self._configs,
},
)


if __name__ == "__main__":
from src.pipelines.common import dry_run_pipeline

dry_run_pipeline(
UserGuideAssistance,
"user_guide_assistance",
query="what can Wren AI do?",
language="en",
)
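One detail worth flagging in `prompt()` above: `path, content = doc.split("\n")` only unpacks when each document is exactly two lines, so an `llms.md` entry whose content spans multiple lines would raise a `ValueError`. A sketch of a more tolerant parser under the same inferred format (documents delimited by `\n---\n`, first line the `.md` path; the format is an assumption read off this code, not documented in the diff):

```python
def parse_docs(raw: str, doc_endpoint_base: str) -> list[dict]:
    """Split llms.md-style text into {path, content} entries."""
    docs = []
    for entry in raw.split("\n---\n"):
        entry = entry.strip()
        if not entry:
            continue
        # partition() tolerates multi-line (or empty) content bodies
        path, _, content = entry.partition("\n")
        docs.append(
            {
                "path": f"{doc_endpoint_base}/{path.replace('.md', '')}",
                "content": content,
            }
        )
    return docs

sample = (
    "getting-started.md\nInstall and run Wren AI.\nThen open the UI.\n"
    "---\nfaq.md\nCommon questions."
)
print(parse_docs(sample, "https://docs.getwren.ai/oss"))
```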
32 changes: 25 additions & 7 deletions wren-ai-service/src/web/v1/services/ask.py
@@ -91,7 +91,9 @@ class AskResultResponse(BaseModel):
]
rephrased_question: Optional[str] = None
intent_reasoning: Optional[str] = None
type: Optional[Literal["MISLEADING_QUERY", "GENERAL", "TEXT_TO_SQL"]] = None
type: Optional[
Literal["MISLEADING_QUERY", "GENERAL", "TEXT_TO_SQL", "USER_GUIDE"]
] = None
response: Optional[List[AskResult]] = None
error: Optional[AskError] = None

@@ -107,6 +109,10 @@ def __init__(
self._ask_results: Dict[str, AskResultResponse] = TTLCache(
maxsize=maxsize, ttl=ttl
)
self._ask_result_type_to_pipeline = {
"GENERAL": self._pipelines["data_assistance"],
"USER_GUIDE": self._pipelines["user_guide_assistance"],
}

def _is_stopped(self, query_id: str):
if (
@@ -222,6 +228,21 @@ async def ask(
)
results["metadata"]["type"] = "GENERAL"
return results
elif intent == "USER_GUIDE":
asyncio.create_task(
self._pipelines["user_guide_assistance"].run(
query=ask_request.query,
language=ask_request.configurations.language,
query_id=ask_request.query_id,
)
)

self._ask_results[query_id] = AskResultResponse(
status="finished",
type="USER_GUIDE",
)
results["metadata"]["type"] = "USER_GUIDE"
return results
else:
self._ask_results[query_id] = AskResultResponse(
status="understanding",
@@ -418,13 +439,10 @@ async def get_ask_streaming_result(
self,
query_id: str,
):
if (
self._ask_results.get(query_id)
and self._ask_results.get(query_id).type == "GENERAL"
if (ask_result := self._ask_results.get(query_id)) and (
pipeline := self._ask_result_type_to_pipeline.get(ask_result.type)
):
async for chunk in self._pipelines["data_assistance"].get_streaming_results(
query_id
):
async for chunk in pipeline.get_streaming_results(query_id):
event = SSEEvent(
data=SSEEvent.SSEEventMessage(message=chunk),
)
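The new `_ask_result_type_to_pipeline` lookup replaces the hard-coded `GENERAL` check, so any result type that maps to a pipeline exposing `get_streaming_results` can stream through the same endpoint. A toy sketch of the pattern in isolation (the pipelines here are stand-ins, and the commented-out future intent is purely hypothetical):

```python
import asyncio

class StubPipeline:
    """Stand-in exposing the same get_streaming_results interface."""

    def __init__(self, chunks):
        self._chunks = chunks

    async def get_streaming_results(self, query_id):
        for chunk in self._chunks:
            yield chunk

type_to_pipeline = {
    "GENERAL": StubPipeline(["general ", "answer"]),
    "USER_GUIDE": StubPipeline(["see the docs at ", "https://docs.getwren.ai"]),
    # "SQL_EXPLANATION": StubPipeline([...]),  # hypothetical future addition
}

async def stream(result_type: str, query_id: str):
    # Same walrus-style dispatch as get_ask_streaming_result above
    if pipeline := type_to_pipeline.get(result_type):
        async for chunk in pipeline.get_streaming_results(query_id):
            print(chunk, end="")
        print()

asyncio.run(stream("USER_GUIDE", "q-1"))
```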