feat(wren-ai-service): Embed the SQL pairs in MDL #1082

Merged: 19 commits, Jan 10, 2025
Changes from all commits
5 changes: 2 additions & 3 deletions deployment/kustomizations/base/cm.yaml
@@ -152,10 +152,9 @@ data:
       document_store: qdrant
     - name: data_assistance
       llm: litellm_llm.gpt-4o-mini-2024-07-18
-    - name: sql_pairs_preparation
+    - name: sql_pairs_indexing
       document_store: qdrant
       embedder: openai_embedder.text-embedding-3-large
-      llm: litellm_llm.gpt-4o-mini-2024-07-18
     - name: sql_pairs_deletion
       document_store: qdrant
       embedder: openai_embedder.text-embedding-3-large
@@ -182,4 +181,4 @@ data:
     langfuse_host: https://cloud.langfuse.com
     langfuse_enable: true
     logging_level: DEBUG
-    development: false
+    development: false
3 changes: 1 addition & 2 deletions docker/config.example.yaml
@@ -104,10 +104,9 @@ pipes:
     document_store: qdrant
   - name: data_assistance
     llm: litellm_llm.gpt-4o-mini-2024-07-18
-  - name: sql_pairs_preparation
+  - name: sql_pairs_indexing
     document_store: qdrant
     embedder: openai_embedder.text-embedding-3-large
-    llm: litellm_llm.gpt-4o-mini-2024-07-18
   - name: sql_pairs_deletion
     document_store: qdrant
     embedder: openai_embedder.text-embedding-3-large
3 changes: 3 additions & 0 deletions wren-ai-service/Justfile
@@ -69,3 +69,6 @@ run-sql mdl_path="" data_source="" sample_dataset="":

 report-gen:
     poetry run python eval/llm_trace_analysis/report_gen.py
+
+mdl-to-str mdl_path="":
+    poetry run python tools/mdl_to_str.py -p {{mdl_path}}
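For reference, the new recipe would be invoked as just mdl-to-str path/to/mdl.json (the path here is illustrative); it wraps tools/mdl_to_str.py, which presumably renders the MDL JSON into the single-string form the indexing pipelines consume as mdl_str.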
2 changes: 2 additions & 0 deletions wren-ai-service/src/config.py
@@ -54,6 +54,8 @@ class Settings(BaseSettings):
     config_path: str = Field(default="config.yaml")
     _components: list[dict]

+    sql_pairs_path: str = Field(default="pairs.json")
+
     def __init__(self):
         load_dotenv(".env.dev", override=True)
         super().__init__()
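The new sql_pairs_path setting points at an optional JSON file of curated question/SQL pairs. A minimal sketch of the layout that _load_sql_pairs and the sql_pairs step of the new pipeline appear to expect — top-level keys are the lowercased boilerplate names taken from MDL model properties, each mapping to a list of id/question/sql entries; all values below are invented for illustration:

{
  "hubspot": [
    {
      "id": "hubspot-001",
      "question": "How many deals were created last month?",
      "sql": "SELECT COUNT(*) FROM deals WHERE created_at >= DATE_TRUNC('month', CURRENT_DATE - INTERVAL '1 month')"
    }
  ]
}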
8 changes: 6 additions & 2 deletions wren-ai-service/src/globals.py
@@ -76,6 +76,10 @@ def create_service_container(
                 "table_description": indexing.TableDescription(
                     **pipe_components["table_description_indexing"],
                 ),
+                "sql_pairs": indexing.SqlPairs(
+                    **pipe_components["sql_pairs_indexing"],
+                    sql_pairs_path=settings.sql_pairs_path,
+                ),
             },
             **query_cache,
         ),
@@ -220,8 +224,8 @@
         ),
         sql_pairs_preparation_service=SqlPairsPreparationService(
             pipelines={
-                "sql_pairs_preparation": indexing.SqlPairsPreparation(
-                    **pipe_components["sql_pairs_preparation"],
+                "sql_pairs_preparation": indexing.SqlPairs(
+                    **pipe_components["sql_pairs_indexing"],
                 ),
                 "sql_pairs_deletion": indexing.SqlPairsDeletion(
                     **pipe_components["sql_pairs_deletion"],
20 changes: 10 additions & 10 deletions wren-ai-service/src/pipelines/generation/sql_generation.py
@@ -32,6 +32,16 @@
 {{ document }}
 {% endfor %}

+{% if sql_samples %}
+### SAMPLES ###
+{% for sample in sql_samples %}
+Question:
+{{sample.question}}
+SQL:
+{{sample.sql}}
+{% endfor %}
+{% endif %}
+
 {% if exclude %}
 ### EXCLUDED STATEMENTS ###
 Ensure that the following excluded statements are not used in the generated queries to maintain variety and avoid repetition.
@@ -54,16 +64,6 @@
 ]
 }

-{% if sql_samples %}
-### SAMPLES ###
-{% for sample in sql_samples %}
-Summary:
-{{sample.summary}}
-SQL:
-{{sample.sql}}
-{% endfor %}
-{% endif %}
-
 ### QUESTION ###
 User's Question: {{ query }}
 Current Time: {{ current_time }}
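The net effect of this diff: the SAMPLES block moves ahead of the EXCLUDED STATEMENTS section and switches from sample.summary to sample.question. A minimal sketch of how the relocated block renders, assuming jinja2 and invented sample data:

from jinja2 import Template

# The SAMPLES block from the prompt above, in isolation.
samples_block = Template(
    "{% if sql_samples %}### SAMPLES ###\n"
    "{% for sample in sql_samples %}Question:\n{{ sample.question }}\n"
    "SQL:\n{{ sample.sql }}\n{% endfor %}{% endif %}"
)

# Invented sample pair; real ones come from the sql_pairs store.
print(samples_block.render(sql_samples=[
    {
        "question": "How many orders were placed last month?",
        "sql": "SELECT COUNT(*) FROM orders WHERE ...",
    }
]))

The rename from summary to question lines up with the SqlPair model introduced below, whose question text is what gets embedded and retrieved.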
20 changes: 20 additions & 0 deletions wren-ai-service/src/pipelines/generation/utils/sql.py
@@ -471,6 +471,26 @@ async def _task(result: Dict[str, str]):
     PurchaseTimestamp >= DATE_TRUNC('month', CURRENT_DATE - INTERVAL '1 month') AND
     PurchaseTimestamp < DATE_TRUNC('month', CURRENT_DATE)

+## LESSON 3 ##
+Finally, you will learn from the sample SQL queries provided in the input. These samples demonstrate best practices and common patterns for querying this specific database.
+
+For each sample, you should:
+1. Study the question that explains what the query aims to accomplish
+2. Analyze the SQL implementation to understand:
+   - Table structures and relationships used
+   - Specific functions and operators employed
+   - Query patterns and techniques demonstrated
+3. Use these samples as reference patterns when generating similar queries
+4. Adapt the techniques shown in the samples to match new query requirements while maintaining consistent style and approach
+
+The samples will help you understand:
+- Preferred table join patterns
+- Common aggregation methods
+- Specific function usage
+- Query structure and formatting conventions
+
+When generating new queries, try to follow similar patterns when applicable, while adapting them to the specific requirements of each new query.
+
 Learn about the usage of the schema structures and generate SQL based on them.

 """
12 changes: 7 additions & 5 deletions wren-ai-service/src/pipelines/indexing/__init__.py
@@ -93,17 +93,19 @@ def __init__(self, sql_pairs_store: DocumentStore) -> None:
         self._sql_pairs_store = sql_pairs_store

     @component.output_types()
-    async def run(self, sql_pair_ids: List[str], id: Optional[str] = None) -> None:
+    async def run(
+        self, sql_pair_ids: List[str], project_id: Optional[str] = None
+    ) -> None:
         filters = {
             "operator": "AND",
             "conditions": [
                 {"field": "sql_pair_id", "operator": "in", "value": sql_pair_ids},
             ],
         }

-        if id:
+        if project_id:
             filters["conditions"].append(
-                {"field": "project_id", "operator": "==", "value": id}
+                {"field": "project_id", "operator": "==", "value": project_id}
             )

         return await self._sql_pairs_store.delete_documents(filters)
@@ -112,14 +114,14 @@
 # Put the pipelines imports here to avoid circular imports and make them accessible directly to the rest of packages
 from .db_schema import DBSchema  # noqa: E402
 from .historical_question import HistoricalQuestion  # noqa: E402
+from .sql_pairs import SqlPairs  # noqa: E402
 from .sql_pairs_deletion import SqlPairsDeletion  # noqa: E402
-from .sql_pairs_preparation import SqlPairsPreparation  # noqa: E402
 from .table_description import TableDescription  # noqa: E402

 __all__ = [
     "DBSchema",
     "TableDescription",
     "HistoricalQuestion",
     "SqlPairsDeletion",
-    "SqlPairsPreparation",
+    "SqlPairs",
 ]
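For context on the cleaner change above: the only behavioral difference is the keyword rename (id becomes project_id); the filter it submits to the document store is unchanged. A sketch with invented values:

# Filter built by SqlPairsCleaner.run(sql_pair_ids=["p1", "p2"], project_id="proj-42"):
{
    "operator": "AND",
    "conditions": [
        {"field": "sql_pair_id", "operator": "in", "value": ["p1", "p2"]},
        {"field": "project_id", "operator": "==", "value": "proj-42"},
    ],
}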
192 changes: 192 additions & 0 deletions wren-ai-service/src/pipelines/indexing/sql_pairs.py
@@ -0,0 +1,192 @@
import logging
import os
import sys
import uuid
from typing import Any, Dict, List, Optional, Set

import orjson
from hamilton import base
from hamilton.async_driver import AsyncDriver
from haystack import Document, component
from haystack.document_stores.types import DuplicatePolicy
from langfuse.decorators import observe
from pydantic import BaseModel

from src.core.pipeline import BasicPipeline
from src.core.provider import DocumentStoreProvider, EmbedderProvider
from src.pipelines.indexing import AsyncDocumentWriter, SqlPairsCleaner

logger = logging.getLogger("wren-ai-service")


class SqlPair(BaseModel):
id: str
sql: str
question: str


@component
class SqlPairsConverter:
@component.output_types(documents=List[Document])
def run(self, sql_pairs: List[SqlPair], project_id: Optional[str] = ""):
logger.info(f"Project ID: {project_id} Converting SQL pairs to documents...")

addition = {"project_id": project_id} if project_id else {}

return {
"documents": [
Document(
id=str(uuid.uuid4()),
meta={
"sql_pair_id": sql_pair.id,
"sql": sql_pair.sql,
**addition,
},
content=sql_pair.question,
)
for sql_pair in sql_pairs
]
}


## Start of Pipeline
@observe(capture_input=False)
def boilerplates(
mdl_str: str,
) -> Set[str]:
mdl = orjson.loads(mdl_str)

return {
boilerplate.lower()
for model in mdl.get("models", [])
if (boilerplate := model.get("properties", {}).get("boilerplate"))
}


@observe(capture_input=False)
def sql_pairs(
boilerplates: Set[str],
external_pairs: Dict[str, Any],
) -> List[SqlPair]:
return [
SqlPair(
id=pair.get("id"),
question=pair.get("question"),
sql=pair.get("sql"),
)
for boilerplate in boilerplates
        if boilerplate in external_pairs
        for pair in external_pairs[boilerplate]
    ]

[Review comment from a Contributor on lines +71 to +77]

⚠️ Potential issue: Ensure the id field is not missing in SqlPair instances

In the sql_pairs function, the id field is retrieved using pair.get("id"), which may return None if the key is missing. Since SqlPair requires id to be a string, a missing id could lead to validation errors or unintended behavior.

Apply this diff to enforce the presence of the id field:

 def sql_pairs(
     boilerplates: Set[str],
     external_pairs: Dict[str, Any],
 ) -> List[SqlPair]:
     return [
         SqlPair(
-            id=pair.get("id"),
+            id=pair["id"],
             question=pair.get("question"),
             sql=pair.get("sql"),
         )
         for boilerplate in boilerplates
         if boilerplate in external_pairs
         for pair in external_pairs[boilerplate]
     ]

Alternatively, provide a default id or handle cases where id might be missing to prevent potential errors.

[End of review comment]


@observe(capture_input=False)
def to_documents(
sql_pairs: List[SqlPair],
document_converter: SqlPairsConverter,
project_id: Optional[str] = "",
) -> Dict[str, Any]:
return document_converter.run(sql_pairs=sql_pairs, project_id=project_id)


@observe(capture_input=False, capture_output=False)
async def embedding(
to_documents: Dict[str, Any],
embedder: Any,
) -> Dict[str, Any]:
return await embedder.run(documents=to_documents["documents"])


@observe(capture_input=False, capture_output=False)
async def clean(
cleaner: SqlPairsCleaner,
sql_pairs: List[SqlPair],
embedding: Dict[str, Any],
project_id: Optional[str] = "",
) -> Dict[str, Any]:
sql_pair_ids = [sql_pair.id for sql_pair in sql_pairs]
await cleaner.run(sql_pair_ids=sql_pair_ids, project_id=project_id)

return embedding


@observe(capture_input=False)
async def write(
clean: Dict[str, Any],
writer: AsyncDocumentWriter,
) -> None:
return await writer.run(documents=clean["documents"])


## End of Pipeline


def _load_sql_pairs(sql_pairs_path: str) -> Dict[str, Any]:
if not sql_pairs_path:
return {}

if not os.path.exists(sql_pairs_path):
logger.warning(f"SQL pairs file not found: {sql_pairs_path}")
return {}

try:
with open(sql_pairs_path, "r") as file:
return orjson.loads(file.read())
except Exception as e:
logger.error(f"Error loading SQL pairs file: {e}")
return {}


class SqlPairs(BasicPipeline):
def __init__(
self,
embedder_provider: EmbedderProvider,
document_store_provider: DocumentStoreProvider,
sql_pairs_path: Optional[str] = "sql_pairs.json",
**kwargs,
) -> None:
store = document_store_provider.get_store(dataset_name="sql_pairs")

self._components = {
"cleaner": SqlPairsCleaner(store),
"embedder": embedder_provider.get_document_embedder(),
"document_converter": SqlPairsConverter(),
"writer": AsyncDocumentWriter(
document_store=store,
policy=DuplicatePolicy.OVERWRITE,
),
"external_pairs": _load_sql_pairs(sql_pairs_path),
}

super().__init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)

@observe(name="SQL Pairs Indexing")
async def run(
self,
mdl_str: str,
project_id: Optional[str] = "",
) -> Dict[str, Any]:
logger.info(
f"Project ID: {project_id} SQL Pairs Indexing pipeline is running..."
)

return await self._pipe.execute(
["write"],
inputs={
"mdl_str": mdl_str,
"project_id": project_id,
**self._components,
},
)


if __name__ == "__main__":
from src.pipelines.common import dry_run_pipeline

dry_run_pipeline(
SqlPairs,
"sql_pairs_indexing",
mdl_str='{"models": [{"properties": {"boilerplate": "hubspot"}}]}',
)
3 changes: 2 additions & 1 deletion wren-ai-service/src/pipelines/indexing/sql_pairs_deletion.py
@@ -20,12 +20,13 @@ async def delete_sql_pairs(
     sql_pair_ids: List[str],
     id: Optional[str] = None,
 ) -> None:
-    return await sql_pairs_cleaner.run(sql_pair_ids=sql_pair_ids, id=id)
+    return await sql_pairs_cleaner.run(sql_pair_ids=sql_pair_ids, project_id=id)


 ## End of Pipeline


+# TODO: consider removing this pipeline and using the function in the sql_pairs_indexing pipeline instead like other indexing pipelines
 class SqlPairsDeletion(BasicPipeline):
     def __init__(
         self,