diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index 7f492ed0d..d1d1d3b61 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -7,21 +7,6 @@ from src.core.pipeline import PipelineComponent from src.core.provider import EmbedderProvider, LLMProvider from src.pipelines import generation, indexing -from src.pipelines.generation import ( - data_assistance, - followup_sql_generation, - intent_classification, - question_recommendation, - relationship_recommendation, - sql_answer, - sql_breakdown, - sql_correction, - sql_expansion, - sql_explanation, - sql_generation, - sql_regeneration, - sql_summary, -) from src.pipelines.retrieval import historical_question, preprocess_sql_data, retrieval from src.web.v1.services.ask import AskService from src.web.v1.services.ask_details import AskDetailsService @@ -91,10 +76,10 @@ def create_service_container( ), ask_service=AskService( pipelines={ - "intent_classification": intent_classification.IntentClassification( + "intent_classification": generation.IntentClassification( **pipe_components["intent_classification"], ), - "data_assistance": data_assistance.DataAssistance( + "data_assistance": generation.DataAssistance( **pipe_components["data_assistance"] ), "retrieval": retrieval.Retrieval( @@ -106,16 +91,16 @@ def create_service_container( "historical_question": historical_question.HistoricalQuestion( **pipe_components["historical_question_retrieval"], ), - "sql_generation": sql_generation.SQLGeneration( + "sql_generation": generation.SQLGeneration( **pipe_components["sql_generation"], ), - "sql_correction": sql_correction.SQLCorrection( + "sql_correction": generation.SQLCorrection( **pipe_components["sql_correction"], ), - "followup_sql_generation": followup_sql_generation.FollowUpSQLGeneration( + "followup_sql_generation": generation.FollowUpSQLGeneration( **pipe_components["followup_sql_generation"], ), - "sql_summary": sql_summary.SQLSummary( + "sql_summary": generation.SQLSummary( **pipe_components["sql_summary"], ), }, @@ -126,7 +111,7 @@ def create_service_container( "preprocess_sql_data": preprocess_sql_data.PreprocessSqlData( **pipe_components["preprocess_sql_data"], ), - "sql_answer": sql_answer.SQLAnswer( + "sql_answer": generation.SQLAnswer( **pipe_components["sql_answer"], ), }, @@ -134,10 +119,10 @@ def create_service_container( ), ask_details_service=AskDetailsService( pipelines={ - "sql_breakdown": sql_breakdown.SQLBreakdown( + "sql_breakdown": generation.SQLBreakdown( **pipe_components["sql_breakdown"], ), - "sql_summary": sql_summary.SQLSummary( + "sql_summary": generation.SQLSummary( **pipe_components["sql_summary"], ), }, @@ -150,13 +135,13 @@ def create_service_container( table_retrieval_size=settings.table_retrieval_size, table_column_retrieval_size=settings.table_column_retrieval_size, ), - "sql_expansion": sql_expansion.SQLExpansion( + "sql_expansion": generation.SQLExpansion( **pipe_components["sql_expansion"], ), - "sql_correction": sql_correction.SQLCorrection( + "sql_correction": generation.SQLCorrection( **pipe_components["sql_correction"], ), - "sql_summary": sql_summary.SQLSummary( + "sql_summary": generation.SQLSummary( **pipe_components["sql_summary"], ), }, @@ -164,7 +149,7 @@ def create_service_container( ), sql_explanation_service=SQLExplanationService( pipelines={ - "sql_explanation": sql_explanation.SQLExplanation( + "sql_explanation": generation.SQLExplanation( **pipe_components["sql_explanation"], ) }, @@ -172,7 +157,7 @@ def create_service_container( ), sql_regeneration_service=SQLRegenerationService( pipelines={ - "sql_regeneration": sql_regeneration.SQLRegeneration( + "sql_regeneration": generation.SQLRegeneration( **pipe_components["sql_regeneration"], ) }, @@ -180,7 +165,7 @@ def create_service_container( ), relationship_recommendation=RelationshipRecommendation( pipelines={ - "relationship_recommendation": relationship_recommendation.RelationshipRecommendation( + "relationship_recommendation": generation.RelationshipRecommendation( **pipe_components["relationship_recommendation"], ) }, @@ -188,7 +173,7 @@ def create_service_container( ), question_recommendation=QuestionRecommendation( pipelines={ - "question_recommendation": question_recommendation.QuestionRecommendation( + "question_recommendation": generation.QuestionRecommendation( **pipe_components["question_recommendation"], ), "retrieval": retrieval.Retrieval( @@ -197,7 +182,7 @@ def create_service_container( table_column_retrieval_size=settings.table_column_retrieval_size, allow_using_db_schemas_without_pruning=settings.allow_using_db_schemas_without_pruning, ), - "sql_generation": sql_generation.SQLGeneration( + "sql_generation": generation.SQLGeneration( **pipe_components["sql_generation"], ), }, diff --git a/wren-ai-service/src/pipelines/generation/__init__.py b/wren-ai-service/src/pipelines/generation/__init__.py index e86c31801..8639d3092 100644 --- a/wren-ai-service/src/pipelines/generation/__init__.py +++ b/wren-ai-service/src/pipelines/generation/__init__.py @@ -1,3 +1,31 @@ +from .data_assistance import DataAssistance +from .followup_sql_generation import FollowUpSQLGeneration +from .intent_classification import IntentClassification +from .question_recommendation import QuestionRecommendation +from .relationship_recommendation import RelationshipRecommendation from .semantics_enrichment import SemanticsEnrichment +from .sql_answer import SQLAnswer +from .sql_breakdown import SQLBreakdown +from .sql_correction import SQLCorrection +from .sql_expansion import SQLExpansion +from .sql_explanation import SQLExplanation +from .sql_generation import SQLGeneration +from .sql_regeneration import SQLRegeneration +from .sql_summary import SQLSummary -__all__ = ["SemanticsEnrichment"] +__all__ = [ + "DataAssistance", + "FollowUpSQLGeneration", + "IntentClassification", + "QuestionRecommendation", + "RelationshipRecommendation", + "SemanticsEnrichment", + "SQLAnswer", + "SQLBreakdown", + "SQLCorrection", + "SQLExpansion", + "SQLExplanation", + "SQLGeneration", + "SQLRegeneration", + "SQLSummary", +]