feat(wren-ai-service): Embed the SQL pairs in MDL (#1082)
paopa authored Jan 10, 2025
1 parent 5a7cd05 commit 186faa7
Showing 21 changed files with 345 additions and 306 deletions.
5 changes: 2 additions & 3 deletions deployment/kustomizations/base/cm.yaml
@@ -152,10 +152,9 @@ data:
        document_store: qdrant
      - name: data_assistance
        llm: litellm_llm.gpt-4o-mini-2024-07-18
-     - name: sql_pairs_preparation
+     - name: sql_pairs_indexing
        document_store: qdrant
        embedder: openai_embedder.text-embedding-3-large
-       llm: litellm_llm.gpt-4o-mini-2024-07-18
      - name: sql_pairs_deletion
        document_store: qdrant
        embedder: openai_embedder.text-embedding-3-large
@@ -182,4 +181,4 @@ data:
    langfuse_host: https://cloud.langfuse.com
    langfuse_enable: true
    logging_level: DEBUG
-   development: false
\ No newline at end of file
+   development: false
3 changes: 1 addition & 2 deletions docker/config.example.yaml
@@ -104,10 +104,9 @@ pipes:
    document_store: qdrant
  - name: data_assistance
    llm: litellm_llm.gpt-4o-mini-2024-07-18
-  - name: sql_pairs_preparation
+  - name: sql_pairs_indexing
    document_store: qdrant
    embedder: openai_embedder.text-embedding-3-large
-    llm: litellm_llm.gpt-4o-mini-2024-07-18
  - name: sql_pairs_deletion
    document_store: qdrant
    embedder: openai_embedder.text-embedding-3-large
3 changes: 3 additions & 0 deletions wren-ai-service/Justfile
@@ -69,3 +69,6 @@ run-sql mdl_path="" data_source="" sample_dataset="":

report-gen:
    poetry run python eval/llm_trace_analysis/report_gen.py
+
+mdl-to-str mdl_path="":
+    poetry run python tools/mdl_to_str.py -p {{mdl_path}}
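As a usage sketch (the path here is hypothetical), the new recipe would be invoked as:

    just mdl-to-str mdl_path=tests/data/mdl.json

which presumably prints the MDL JSON as a single string, matching the mdl_str input the pipelines take.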
2 changes: 2 additions & 0 deletions wren-ai-service/src/config.py
@@ -54,6 +54,8 @@ class Settings(BaseSettings):
    config_path: str = Field(default="config.yaml")
    _components: list[dict]

+    sql_pairs_path: str = Field(default="pairs.json")
+
    def __init__(self):
        load_dotenv(".env.dev", override=True)
        super().__init__()
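Since Settings extends pydantic's BaseSettings and load_dotenv(".env.dev", override=True) runs before initialization, the new default can presumably be overridden without a code change via an entry such as SQL_PAIRS_PATH=pairs.json (hypothetical value) in .env.dev.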
8 changes: 6 additions & 2 deletions wren-ai-service/src/globals.py
@@ -76,6 +76,10 @@ def create_service_container(
                "table_description": indexing.TableDescription(
                    **pipe_components["table_description_indexing"],
                ),
+                "sql_pairs": indexing.SqlPairs(
+                    **pipe_components["sql_pairs_indexing"],
+                    sql_pairs_path=settings.sql_pairs_path,
+                ),
            },
            **query_cache,
        ),
@@ -220,8 +224,8 @@
        ),
        sql_pairs_preparation_service=SqlPairsPreparationService(
            pipelines={
-                "sql_pairs_preparation": indexing.SqlPairsPreparation(
-                    **pipe_components["sql_pairs_preparation"],
+                "sql_pairs_preparation": indexing.SqlPairs(
+                    **pipe_components["sql_pairs_indexing"],
                ),
                "sql_pairs_deletion": indexing.SqlPairsDeletion(
                    **pipe_components["sql_pairs_deletion"],
20 changes: 10 additions & 10 deletions wren-ai-service/src/pipelines/generation/sql_generation.py
@@ -32,6 +32,16 @@
{{ document }}
{% endfor %}

+{% if sql_samples %}
+### SAMPLES ###
+{% for sample in sql_samples %}
+Question:
+{{sample.question}}
+SQL:
+{{sample.sql}}
+{% endfor %}
+{% endif %}
+
{% if exclude %}
### EXCLUDED STATEMENTS ###
Ensure that the following excluded statements are not used in the generated queries to maintain variety and avoid repetition.
@@ -54,16 +64,6 @@
    ]
}

-{% if sql_samples %}
-### SAMPLES ###
-{% for sample in sql_samples %}
-Summary:
-{{sample.summary}}
-SQL:
-{{sample.sql}}
-{% endfor %}
-{% endif %}
-
### QUESTION ###
User's Question: {{ query }}
Current Time: {{ current_time }}
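For context, once sql_samples is populated, the relocated block above renders each pair into the prompt roughly like this (hypothetical values):

    ### SAMPLES ###
    Question:
    How many orders were placed last month?
    SQL:
    SELECT COUNT(*) FROM orders WHERE ...

Note the section now precedes the output-format instructions and keys each sample by its question rather than a summary.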
20 changes: 20 additions & 0 deletions wren-ai-service/src/pipelines/generation/utils/sql.py
@@ -471,6 +471,26 @@ async def _task(result: Dict[str, str]):
    PurchaseTimestamp >= DATE_TRUNC('month', CURRENT_DATE - INTERVAL '1 month') AND
    PurchaseTimestamp < DATE_TRUNC('month', CURRENT_DATE)

+## LESSON 3 ##
+Finally, you will learn from the sample SQL queries provided in the input. These samples demonstrate best practices and common patterns for querying this specific database.
+
+For each sample, you should:
+1. Study the question that explains what the query aims to accomplish
+2. Analyze the SQL implementation to understand:
+   - Table structures and relationships used
+   - Specific functions and operators employed
+   - Query patterns and techniques demonstrated
+3. Use these samples as reference patterns when generating similar queries
+4. Adapt the techniques shown in the samples to match new query requirements while maintaining consistent style and approach
+
+The samples will help you understand:
+- Preferred table join patterns
+- Common aggregation methods
+- Specific function usage
+- Query structure and formatting conventions
+
+When generating new queries, try to follow similar patterns when applicable, while adapting them to the specific requirements of each new query.
+Learn about the usage of the schema structures and generate SQL based on them.
"""
12 changes: 7 additions & 5 deletions wren-ai-service/src/pipelines/indexing/__init__.py
@@ -93,17 +93,19 @@ def __init__(self, sql_pairs_store: DocumentStore) -> None:
        self._sql_pairs_store = sql_pairs_store

    @component.output_types()
-    async def run(self, sql_pair_ids: List[str], id: Optional[str] = None) -> None:
+    async def run(
+        self, sql_pair_ids: List[str], project_id: Optional[str] = None
+    ) -> None:
        filters = {
            "operator": "AND",
            "conditions": [
                {"field": "sql_pair_id", "operator": "in", "value": sql_pair_ids},
            ],
        }

-        if id:
+        if project_id:
            filters["conditions"].append(
-                {"field": "project_id", "operator": "==", "value": id}
+                {"field": "project_id", "operator": "==", "value": project_id}
            )

        return await self._sql_pairs_store.delete_documents(filters)
@@ -112,14 +114,14 @@ async def run(self, sql_pair_ids: List[str], id: Optional[str] = None) -> None:
# Put the pipelines imports here to avoid circular imports and make them accessible directly to the rest of packages
from .db_schema import DBSchema  # noqa: E402
from .historical_question import HistoricalQuestion  # noqa: E402
+from .sql_pairs import SqlPairs  # noqa: E402
from .sql_pairs_deletion import SqlPairsDeletion  # noqa: E402
-from .sql_pairs_preparation import SqlPairsPreparation  # noqa: E402
from .table_description import TableDescription  # noqa: E402

__all__ = [
    "DBSchema",
    "TableDescription",
    "HistoricalQuestion",
    "SqlPairsDeletion",
-    "SqlPairsPreparation",
+    "SqlPairs",
]
192 changes: 192 additions & 0 deletions wren-ai-service/src/pipelines/indexing/sql_pairs.py
@@ -0,0 +1,192 @@
import logging
import os
import sys
import uuid
from typing import Any, Dict, List, Optional, Set

import orjson
from hamilton import base
from hamilton.async_driver import AsyncDriver
from haystack import Document, component
from haystack.document_stores.types import DuplicatePolicy
from langfuse.decorators import observe
from pydantic import BaseModel

from src.core.pipeline import BasicPipeline
from src.core.provider import DocumentStoreProvider, EmbedderProvider
from src.pipelines.indexing import AsyncDocumentWriter, SqlPairsCleaner

logger = logging.getLogger("wren-ai-service")


class SqlPair(BaseModel):
    id: str
    sql: str
    question: str


@component
class SqlPairsConverter:
    @component.output_types(documents=List[Document])
    def run(self, sql_pairs: List[SqlPair], project_id: Optional[str] = ""):
        logger.info(f"Project ID: {project_id} Converting SQL pairs to documents...")

        addition = {"project_id": project_id} if project_id else {}

        return {
            "documents": [
                Document(
                    id=str(uuid.uuid4()),
                    meta={
                        "sql_pair_id": sql_pair.id,
                        "sql": sql_pair.sql,
                        **addition,
                    },
                    content=sql_pair.question,
                )
                for sql_pair in sql_pairs
            ]
        }


## Start of Pipeline
@observe(capture_input=False)
def boilerplates(
    mdl_str: str,
) -> Set[str]:
    mdl = orjson.loads(mdl_str)

    return {
        boilerplate.lower()
        for model in mdl.get("models", [])
        if (boilerplate := model.get("properties", {}).get("boilerplate"))
    }
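# Illustrative only (hypothetical MDL): given
# '{"models": [{"properties": {"boilerplate": "Hubspot"}}]}', this returns
# {"hubspot"}; names are lower-cased so the lookup below is case-insensitive.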


@observe(capture_input=False)
def sql_pairs(
    boilerplates: Set[str],
    external_pairs: Dict[str, Any],
) -> List[SqlPair]:
    return [
        SqlPair(
            id=pair.get("id"),
            question=pair.get("question"),
            sql=pair.get("sql"),
        )
        for boilerplate in boilerplates
        if boilerplate in external_pairs
        for pair in external_pairs[boilerplate]
    ]
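# Illustrative only (hypothetical data): with boilerplates == {"hubspot"} and
# external_pairs == {"hubspot": [{"id": "1", "question": "...", "sql": "..."}]},
# only the hubspot pairs become SqlPair models; boilerplates with no entry in
# external_pairs are skipped.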


@observe(capture_input=False)
def to_documents(
    sql_pairs: List[SqlPair],
    document_converter: SqlPairsConverter,
    project_id: Optional[str] = "",
) -> Dict[str, Any]:
    return document_converter.run(sql_pairs=sql_pairs, project_id=project_id)


@observe(capture_input=False, capture_output=False)
async def embedding(
    to_documents: Dict[str, Any],
    embedder: Any,
) -> Dict[str, Any]:
    return await embedder.run(documents=to_documents["documents"])


@observe(capture_input=False, capture_output=False)
async def clean(
    cleaner: SqlPairsCleaner,
    sql_pairs: List[SqlPair],
    embedding: Dict[str, Any],
    project_id: Optional[str] = "",
) -> Dict[str, Any]:
    sql_pair_ids = [sql_pair.id for sql_pair in sql_pairs]
    await cleaner.run(sql_pair_ids=sql_pair_ids, project_id=project_id)

    return embedding


@observe(capture_input=False)
async def write(
    clean: Dict[str, Any],
    writer: AsyncDocumentWriter,
) -> None:
    return await writer.run(documents=clean["documents"])


## End of Pipeline


def _load_sql_pairs(sql_pairs_path: str) -> Dict[str, Any]:
    if not sql_pairs_path:
        return {}

    if not os.path.exists(sql_pairs_path):
        logger.warning(f"SQL pairs file not found: {sql_pairs_path}")
        return {}

    try:
        with open(sql_pairs_path, "r") as file:
            return orjson.loads(file.read())
    except Exception as e:
        logger.error(f"Error loading SQL pairs file: {e}")
        return {}
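# The file is expected to map boilerplate names to lists of pairs, e.g.
# (hypothetical content) {"hubspot": [{"id": "1", "question": "...", "sql": "..."}]}.
# A missing or unreadable file degrades to an empty mapping rather than failing.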


class SqlPairs(BasicPipeline):
    def __init__(
        self,
        embedder_provider: EmbedderProvider,
        document_store_provider: DocumentStoreProvider,
        sql_pairs_path: Optional[str] = "sql_pairs.json",
        **kwargs,
    ) -> None:
        store = document_store_provider.get_store(dataset_name="sql_pairs")

        self._components = {
            "cleaner": SqlPairsCleaner(store),
            "embedder": embedder_provider.get_document_embedder(),
            "document_converter": SqlPairsConverter(),
            "writer": AsyncDocumentWriter(
                document_store=store,
                policy=DuplicatePolicy.OVERWRITE,
            ),
            "external_pairs": _load_sql_pairs(sql_pairs_path),
        }

        super().__init__(
            AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
        )

    @observe(name="SQL Pairs Indexing")
    async def run(
        self,
        mdl_str: str,
        project_id: Optional[str] = "",
    ) -> Dict[str, Any]:
        logger.info(
            f"Project ID: {project_id} SQL Pairs Indexing pipeline is running..."
        )

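        # Hamilton resolves the graph by parameter name, so requesting "write"
        # should pull in clean -> embedding -> to_documents -> sql_pairs ->
        # boilerplates (editorial note: flow inferred from the functions above).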
        return await self._pipe.execute(
            ["write"],
            inputs={
                "mdl_str": mdl_str,
                "project_id": project_id,
                **self._components,
            },
        )


if __name__ == "__main__":
    from src.pipelines.common import dry_run_pipeline

    dry_run_pipeline(
        SqlPairs,
        "sql_pairs_indexing",
        mdl_str='{"models": [{"properties": {"boilerplate": "hubspot"}}]}',
    )
3 changes: 2 additions & 1 deletion wren-ai-service/src/pipelines/indexing/sql_pairs_deletion.py
@@ -20,12 +20,13 @@ async def delete_sql_pairs(
    sql_pair_ids: List[str],
    id: Optional[str] = None,
) -> None:
-    return await sql_pairs_cleaner.run(sql_pair_ids=sql_pair_ids, id=id)
+    return await sql_pairs_cleaner.run(sql_pair_ids=sql_pair_ids, project_id=id)


## End of Pipeline


+# TODO: consider removing this pipeline and using the function in the sql_pairs_indexing pipeline instead like other indexing pipelines
class SqlPairsDeletion(BasicPipeline):
    def __init__(
        self,