diff --git a/wren-ai-service/src/pipelines/indexing/sql_pairs.py b/wren-ai-service/src/pipelines/indexing/sql_pairs.py index b8616198d..c5c322d30 100644 --- a/wren-ai-service/src/pipelines/indexing/sql_pairs.py +++ b/wren-ai-service/src/pipelines/indexing/sql_pairs.py @@ -2,7 +2,7 @@ import os import sys import uuid -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set import orjson from hamilton import base @@ -53,7 +53,7 @@ def run(self, sql_pairs: List[SqlPair], project_id: Optional[str] = ""): @observe(capture_input=False) def boilerplates( mdl_str: str, -) -> List[str]: +) -> Set[str]: mdl = orjson.loads(mdl_str) return { @@ -129,8 +129,12 @@ def _load_sql_pairs(sql_pairs_path: str) -> Dict[str, Any]: logger.warning(f"SQL pairs file not found: {sql_pairs_path}") return {} - with open(sql_pairs_path, "r") as file: - return orjson.loads(file.read()) + try: + with open(sql_pairs_path, "r") as file: + return orjson.loads(file.read()) + except Exception as e: + logger.error(f"Error loading SQL pairs file: {e}") + return {} class SqlPairs(BasicPipeline): diff --git a/wren-ai-service/tools/mdl_to_str.py b/wren-ai-service/tools/mdl_to_str.py index cfab8f8bf..d7561e738 100644 --- a/wren-ai-service/tools/mdl_to_str.py +++ b/wren-ai-service/tools/mdl_to_str.py @@ -1,3 +1,5 @@ +import argparse + import orjson @@ -38,8 +40,6 @@ def _args(): if __name__ == "__main__": - import argparse - args = _args() mdl = orjson.loads(open(args.path).read()) print(to_str(mdl))