Skip to content

Commit

Permalink
chore: expose the class from package level for generation pipe
Browse files Browse the repository at this point in the history
  • Loading branch information
paopa committed Dec 10, 2024
1 parent f8503d3 commit bf38415
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 33 deletions.
49 changes: 17 additions & 32 deletions wren-ai-service/src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,6 @@
from src.core.pipeline import PipelineComponent
from src.core.provider import EmbedderProvider, LLMProvider
from src.pipelines import generation, indexing
from src.pipelines.generation import (
data_assistance,
followup_sql_generation,
intent_classification,
question_recommendation,
relationship_recommendation,
sql_answer,
sql_breakdown,
sql_correction,
sql_expansion,
sql_explanation,
sql_generation,
sql_regeneration,
sql_summary,
)
from src.pipelines.retrieval import historical_question, preprocess_sql_data, retrieval
from src.web.v1.services.ask import AskService
from src.web.v1.services.ask_details import AskDetailsService
Expand Down Expand Up @@ -91,10 +76,10 @@ def create_service_container(
),
ask_service=AskService(
pipelines={
"intent_classification": intent_classification.IntentClassification(
"intent_classification": generation.IntentClassification(
**pipe_components["intent_classification"],
),
"data_assistance": data_assistance.DataAssistance(
"data_assistance": generation.DataAssistance(
**pipe_components["data_assistance"]
),
"retrieval": retrieval.Retrieval(
Expand All @@ -106,16 +91,16 @@ def create_service_container(
"historical_question": historical_question.HistoricalQuestion(
**pipe_components["historical_question_retrieval"],
),
"sql_generation": sql_generation.SQLGeneration(
"sql_generation": generation.SQLGeneration(
**pipe_components["sql_generation"],
),
"sql_correction": sql_correction.SQLCorrection(
"sql_correction": generation.SQLCorrection(
**pipe_components["sql_correction"],
),
"followup_sql_generation": followup_sql_generation.FollowUpSQLGeneration(
"followup_sql_generation": generation.FollowUpSQLGeneration(
**pipe_components["followup_sql_generation"],
),
"sql_summary": sql_summary.SQLSummary(
"sql_summary": generation.SQLSummary(
**pipe_components["sql_summary"],
),
},
Expand All @@ -126,18 +111,18 @@ def create_service_container(
"preprocess_sql_data": preprocess_sql_data.PreprocessSqlData(
**pipe_components["preprocess_sql_data"],
),
"sql_answer": sql_answer.SQLAnswer(
"sql_answer": generation.SQLAnswer(
**pipe_components["sql_answer"],
),
},
**query_cache,
),
ask_details_service=AskDetailsService(
pipelines={
"sql_breakdown": sql_breakdown.SQLBreakdown(
"sql_breakdown": generation.SQLBreakdown(
**pipe_components["sql_breakdown"],
),
"sql_summary": sql_summary.SQLSummary(
"sql_summary": generation.SQLSummary(
**pipe_components["sql_summary"],
),
},
Expand All @@ -150,45 +135,45 @@ def create_service_container(
table_retrieval_size=settings.table_retrieval_size,
table_column_retrieval_size=settings.table_column_retrieval_size,
),
"sql_expansion": sql_expansion.SQLExpansion(
"sql_expansion": generation.SQLExpansion(
**pipe_components["sql_expansion"],
),
"sql_correction": sql_correction.SQLCorrection(
"sql_correction": generation.SQLCorrection(
**pipe_components["sql_correction"],
),
"sql_summary": sql_summary.SQLSummary(
"sql_summary": generation.SQLSummary(
**pipe_components["sql_summary"],
),
},
**query_cache,
),
sql_explanation_service=SQLExplanationService(
pipelines={
"sql_explanation": sql_explanation.SQLExplanation(
"sql_explanation": generation.SQLExplanation(
**pipe_components["sql_explanation"],
)
},
**query_cache,
),
sql_regeneration_service=SQLRegenerationService(
pipelines={
"sql_regeneration": sql_regeneration.SQLRegeneration(
"sql_regeneration": generation.SQLRegeneration(
**pipe_components["sql_regeneration"],
)
},
**query_cache,
),
relationship_recommendation=RelationshipRecommendation(
pipelines={
"relationship_recommendation": relationship_recommendation.RelationshipRecommendation(
"relationship_recommendation": generation.RelationshipRecommendation(
**pipe_components["relationship_recommendation"],
)
},
**query_cache,
),
question_recommendation=QuestionRecommendation(
pipelines={
"question_recommendation": question_recommendation.QuestionRecommendation(
"question_recommendation": generation.QuestionRecommendation(
**pipe_components["question_recommendation"],
),
"retrieval": retrieval.Retrieval(
Expand All @@ -197,7 +182,7 @@ def create_service_container(
table_column_retrieval_size=settings.table_column_retrieval_size,
allow_using_db_schemas_without_pruning=settings.allow_using_db_schemas_without_pruning,
),
"sql_generation": sql_generation.SQLGeneration(
"sql_generation": generation.SQLGeneration(
**pipe_components["sql_generation"],
),
},
Expand Down
30 changes: 29 additions & 1 deletion wren-ai-service/src/pipelines/generation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,31 @@
from .data_assistance import DataAssistance
from .followup_sql_generation import FollowUpSQLGeneration
from .intent_classification import IntentClassification
from .question_recommendation import QuestionRecommendation
from .relationship_recommendation import RelationshipRecommendation
from .semantics_enrichment import SemanticsEnrichment
from .sql_answer import SQLAnswer
from .sql_breakdown import SQLBreakdown
from .sql_correction import SQLCorrection
from .sql_expansion import SQLExpansion
from .sql_explanation import SQLExplanation
from .sql_generation import SQLGeneration
from .sql_regeneration import SQLRegeneration
from .sql_summary import SQLSummary

__all__ = ["SemanticsEnrichment"]
__all__ = [
"DataAssistance",
"FollowUpSQLGeneration",
"IntentClassification",
"QuestionRecommendation",
"RelationshipRecommendation",
"SemanticsEnrichment",
"SQLAnswer",
"SQLBreakdown",
"SQLCorrection",
"SQLExpansion",
"SQLExplanation",
"SQLGeneration",
"SQLRegeneration",
"SQLSummary",
]

0 comments on commit bf38415

Please sign in to comment.