-
Notifications
You must be signed in to change notification settings - Fork 285
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(wren-ai-service): Embed the SQL pairs in MDL #1082
Merged
+345
−306
Merged
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
8cf291e
chore: refactor partial code
paopa f40e0bc
feat: pushdown the sql pair class to the core
paopa 5340ee7
feat: tool to convert mdl to str for http request testing
paopa ea99262
feat: index sql pairs if MDL includes the key
paopa e4df29b
fix: failed test cases
paopa b7c759c
fix: correct the behavior for document id
paopa 2819a4f
feat: prompt enhancement for SQL pairs
paopa 7d2b088
feat: use question instead of sql summary
paopa 6d56393
feat: sql pairs from external file
paopa 7399656
add the todo for sql pair preparation endpoint
paopa 8cbb369
chore: comment for deletion pipe
paopa 05f1e8f
fix: test cases
paopa 5d87164
chore: remove unused attribute in doc
paopa eb4619c
chore: skip some test cases for sql pair service
paopa 513350d
fix: test cases for sql pair and deletion
paopa 923f5ce
chore: mount sql pair file for each deployment
paopa 254c180
chore: deal import, type and error handling suggestion
paopa c79f907
chore: rename test case, and correct pipeline type
paopa 81f13bf
chore: remove empty pairs file from each deployment
paopa File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
import logging | ||
import os | ||
import sys | ||
import uuid | ||
from typing import Any, Dict, List, Optional, Set | ||
|
||
import orjson | ||
from hamilton import base | ||
from hamilton.async_driver import AsyncDriver | ||
from haystack import Document, component | ||
from haystack.document_stores.types import DuplicatePolicy | ||
from langfuse.decorators import observe | ||
from pydantic import BaseModel | ||
|
||
from src.core.pipeline import BasicPipeline | ||
from src.core.provider import DocumentStoreProvider, EmbedderProvider | ||
from src.pipelines.indexing import AsyncDocumentWriter, SqlPairsCleaner | ||
|
||
logger = logging.getLogger("wren-ai-service") | ||
|
||
|
||
class SqlPair(BaseModel): | ||
id: str | ||
sql: str | ||
question: str | ||
|
||
|
||
@component | ||
class SqlPairsConverter: | ||
@component.output_types(documents=List[Document]) | ||
def run(self, sql_pairs: List[SqlPair], project_id: Optional[str] = ""): | ||
logger.info(f"Project ID: {project_id} Converting SQL pairs to documents...") | ||
|
||
addition = {"project_id": project_id} if project_id else {} | ||
|
||
return { | ||
"documents": [ | ||
Document( | ||
id=str(uuid.uuid4()), | ||
meta={ | ||
"sql_pair_id": sql_pair.id, | ||
"sql": sql_pair.sql, | ||
**addition, | ||
}, | ||
content=sql_pair.question, | ||
) | ||
for sql_pair in sql_pairs | ||
] | ||
} | ||
|
||
|
||
## Start of Pipeline | ||
@observe(capture_input=False) | ||
def boilerplates( | ||
mdl_str: str, | ||
) -> Set[str]: | ||
mdl = orjson.loads(mdl_str) | ||
|
||
return { | ||
boilerplate.lower() | ||
for model in mdl.get("models", []) | ||
if (boilerplate := model.get("properties", {}).get("boilerplate")) | ||
} | ||
|
||
|
||
@observe(capture_input=False) | ||
def sql_pairs( | ||
boilerplates: Set[str], | ||
external_pairs: Dict[str, Any], | ||
) -> List[SqlPair]: | ||
return [ | ||
SqlPair( | ||
id=pair.get("id"), | ||
question=pair.get("question"), | ||
sql=pair.get("sql"), | ||
) | ||
for boilerplate in boilerplates | ||
if boilerplate in external_pairs | ||
for pair in external_pairs[boilerplate] | ||
] | ||
|
||
|
||
@observe(capture_input=False) | ||
def to_documents( | ||
sql_pairs: List[SqlPair], | ||
document_converter: SqlPairsConverter, | ||
project_id: Optional[str] = "", | ||
) -> Dict[str, Any]: | ||
return document_converter.run(sql_pairs=sql_pairs, project_id=project_id) | ||
|
||
|
||
@observe(capture_input=False, capture_output=False) | ||
async def embedding( | ||
to_documents: Dict[str, Any], | ||
embedder: Any, | ||
) -> Dict[str, Any]: | ||
return await embedder.run(documents=to_documents["documents"]) | ||
|
||
|
||
@observe(capture_input=False, capture_output=False) | ||
async def clean( | ||
cleaner: SqlPairsCleaner, | ||
sql_pairs: List[SqlPair], | ||
embedding: Dict[str, Any], | ||
project_id: Optional[str] = "", | ||
) -> Dict[str, Any]: | ||
sql_pair_ids = [sql_pair.id for sql_pair in sql_pairs] | ||
await cleaner.run(sql_pair_ids=sql_pair_ids, project_id=project_id) | ||
|
||
return embedding | ||
|
||
|
||
@observe(capture_input=False) | ||
async def write( | ||
clean: Dict[str, Any], | ||
writer: AsyncDocumentWriter, | ||
) -> None: | ||
return await writer.run(documents=clean["documents"]) | ||
|
||
|
||
## End of Pipeline | ||
|
||
|
||
def _load_sql_pairs(sql_pairs_path: str) -> Dict[str, Any]: | ||
if not sql_pairs_path: | ||
return {} | ||
|
||
if not os.path.exists(sql_pairs_path): | ||
logger.warning(f"SQL pairs file not found: {sql_pairs_path}") | ||
return {} | ||
|
||
try: | ||
with open(sql_pairs_path, "r") as file: | ||
return orjson.loads(file.read()) | ||
except Exception as e: | ||
logger.error(f"Error loading SQL pairs file: {e}") | ||
return {} | ||
|
||
|
||
class SqlPairs(BasicPipeline): | ||
def __init__( | ||
self, | ||
embedder_provider: EmbedderProvider, | ||
document_store_provider: DocumentStoreProvider, | ||
sql_pairs_path: Optional[str] = "sql_pairs.json", | ||
**kwargs, | ||
) -> None: | ||
store = document_store_provider.get_store(dataset_name="sql_pairs") | ||
|
||
self._components = { | ||
"cleaner": SqlPairsCleaner(store), | ||
"embedder": embedder_provider.get_document_embedder(), | ||
"document_converter": SqlPairsConverter(), | ||
"writer": AsyncDocumentWriter( | ||
document_store=store, | ||
policy=DuplicatePolicy.OVERWRITE, | ||
), | ||
"external_pairs": _load_sql_pairs(sql_pairs_path), | ||
} | ||
|
||
super().__init__( | ||
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) | ||
) | ||
|
||
@observe(name="SQL Pairs Indexing") | ||
async def run( | ||
self, | ||
mdl_str: str, | ||
project_id: Optional[str] = "", | ||
) -> Dict[str, Any]: | ||
logger.info( | ||
f"Project ID: {project_id} SQL Pairs Indexing pipeline is running..." | ||
) | ||
|
||
return await self._pipe.execute( | ||
["write"], | ||
inputs={ | ||
"mdl_str": mdl_str, | ||
"project_id": project_id, | ||
**self._components, | ||
}, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
from src.pipelines.common import dry_run_pipeline | ||
|
||
dry_run_pipeline( | ||
SqlPairs, | ||
"sql_pairs_indexing", | ||
mdl_str='{"models": [{"properties": {"boilerplate": "hubspot"}}]}', | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ensure
id
field is not missing inSqlPair
instancesIn the
sql_pairs
function, theid
field is retrieved usingpair.get("id")
, which may returnNone
if the key is missing. SinceSqlPair
requiresid
to be a string, a missingid
could lead to validation errors or unintended behavior.Apply this diff to enforce the presence of the
id
field:Alternatively, provide a default
id
or handle cases whereid
might be missing to prevent potential errors.📝 Committable suggestion