diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index d1d1d3b61..18b6e19ef 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -6,8 +6,7 @@ from src.config import Settings from src.core.pipeline import PipelineComponent from src.core.provider import EmbedderProvider, LLMProvider -from src.pipelines import generation, indexing -from src.pipelines.retrieval import historical_question, preprocess_sql_data, retrieval +from src.pipelines import generation, indexing, retrieval from src.web.v1.services.ask import AskService from src.web.v1.services.ask_details import AskDetailsService from src.web.v1.services.question_recommendation import QuestionRecommendation @@ -88,7 +87,7 @@ def create_service_container( table_column_retrieval_size=settings.table_column_retrieval_size, allow_using_db_schemas_without_pruning=settings.allow_using_db_schemas_without_pruning, ), - "historical_question": historical_question.HistoricalQuestion( + "historical_question": retrieval.HistoricalQuestion( **pipe_components["historical_question_retrieval"], ), "sql_generation": generation.SQLGeneration( @@ -108,7 +107,7 @@ def create_service_container( ), sql_answer_service=SqlAnswerService( pipelines={ - "preprocess_sql_data": preprocess_sql_data.PreprocessSqlData( + "preprocess_sql_data": retrieval.PreprocessSqlData( **pipe_components["preprocess_sql_data"], ), "sql_answer": generation.SQLAnswer( diff --git a/wren-ai-service/src/pipelines/retrieval/__init__.py b/wren-ai-service/src/pipelines/retrieval/__init__.py index e69de29bb..e89fde241 100644 --- a/wren-ai-service/src/pipelines/retrieval/__init__.py +++ b/wren-ai-service/src/pipelines/retrieval/__init__.py @@ -0,0 +1,5 @@ +from .historical_question import HistoricalQuestion +from .preprocess_sql_data import PreprocessSqlData +from .retrieval import Retrieval + +__all__ = ["HistoricalQuestion", "PreprocessSqlData", "Retrieval"]