Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(wren-ai-service): generate semantics for alias #976

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion deployment/kustomizations/base/cm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ data:
- name: sql_regeneration
llm: litellm_llm.gpt-4o-mini-2024-07-18
engine: wren_ui
- name: semantics_description
- name: semantics_enrichment
llm: litellm_llm.gpt-4o-mini-2024-07-18
- name: relationship_recommendation
llm: litellm_llm.gpt-4o-mini-2024-07-18
Expand Down
2 changes: 1 addition & 1 deletion docker/config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ pipes:
- name: sql_regeneration
llm: litellm_llm.gpt-4o-mini-2024-07-18
engine: wren_ui
- name: semantics_description
- name: semantics_enrichment
llm: litellm_llm.gpt-4o-mini-2024-07-18
- name: relationship_recommendation
llm: litellm_llm.gpt-4o-mini-2024-07-18
Expand Down
82 changes: 29 additions & 53 deletions wren-ai-service/src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,14 @@
from src.config import Settings
from src.core.pipeline import PipelineComponent
from src.core.provider import EmbedderProvider, LLMProvider
from src.pipelines import indexing
from src.pipelines.generation import (
chart_adjustment,
chart_generation,
data_assistance,
followup_sql_generation,
intent_classification,
question_recommendation,
relationship_recommendation,
semantics_description,
sql_answer,
sql_breakdown,
sql_correction,
sql_expansion,
sql_explanation,
sql_generation,
sql_regeneration,
sql_summary,
)
from src.pipelines.retrieval import (
historical_question,
preprocess_sql_data,
retrieval,
sql_executor,
)
from src.pipelines import generation, indexing, retrieval
from src.web.v1.services.ask import AskService
from src.web.v1.services.ask_details import AskDetailsService
from src.web.v1.services.chart import ChartService
from src.web.v1.services.chart_adjustment import ChartAdjustmentService
from src.web.v1.services.question_recommendation import QuestionRecommendation
from src.web.v1.services.relationship_recommendation import RelationshipRecommendation
from src.web.v1.services.semantics_description import SemanticsDescription
from src.web.v1.services.semantics_enrichment import SemanticsEnrichment
from src.web.v1.services.semantics_preparation import SemanticsPreparationService
from src.web.v1.services.sql_answer import SqlAnswerService
from src.web.v1.services.sql_expansion import SqlExpansionService
Expand All @@ -53,7 +29,7 @@ class ServiceContainer:
ask_details_service: AskDetailsService
question_recommendation: QuestionRecommendation
relationship_recommendation: RelationshipRecommendation
semantics_description: SemanticsDescription
semantics_enrichment: SemanticsEnrichment
semantics_preparation_service: SemanticsPreparationService
chart_service: ChartService
chart_adjustment_service: ChartAdjustmentService
Expand All @@ -78,10 +54,10 @@ def create_service_container(
"ttl": settings.query_cache_ttl,
}
return ServiceContainer(
semantics_description=SemanticsDescription(
semantics_enrichment=SemanticsEnrichment(
pipelines={
"semantics_description": semantics_description.SemanticsDescription(
**pipe_components["semantics_description"],
"semantics_enrichment": generation.SemanticsEnrichment(
**pipe_components["semantics_enrichment"],
)
},
**query_cache,
Expand All @@ -103,10 +79,10 @@ def create_service_container(
),
ask_service=AskService(
pipelines={
"intent_classification": intent_classification.IntentClassification(
"intent_classification": generation.IntentClassification(
**pipe_components["intent_classification"],
),
"data_assistance": data_assistance.DataAssistance(
"data_assistance": generation.DataAssistance(
**pipe_components["data_assistance"]
),
"retrieval": retrieval.Retrieval(
Expand All @@ -115,63 +91,63 @@ def create_service_container(
table_column_retrieval_size=settings.table_column_retrieval_size,
allow_using_db_schemas_without_pruning=settings.allow_using_db_schemas_without_pruning,
),
"historical_question": historical_question.HistoricalQuestion(
"historical_question": retrieval.HistoricalQuestion(
**pipe_components["historical_question_retrieval"],
),
"sql_generation": sql_generation.SQLGeneration(
"sql_generation": generation.SQLGeneration(
**pipe_components["sql_generation"],
),
"sql_correction": sql_correction.SQLCorrection(
"sql_correction": generation.SQLCorrection(
**pipe_components["sql_correction"],
),
"followup_sql_generation": followup_sql_generation.FollowUpSQLGeneration(
"followup_sql_generation": generation.FollowUpSQLGeneration(
**pipe_components["followup_sql_generation"],
),
"sql_summary": sql_summary.SQLSummary(
"sql_summary": generation.SQLSummary(
**pipe_components["sql_summary"],
),
},
**query_cache,
),
chart_service=ChartService(
pipelines={
"sql_executor": sql_executor.SQLExecutor(
"sql_executor": retrieval.SQLExecutor(
**pipe_components["sql_executor"],
),
"chart_generation": chart_generation.ChartGeneration(
"chart_generation": generation.ChartGeneration(
**pipe_components["chart_generation"],
),
},
**query_cache,
),
chart_adjustment_service=ChartAdjustmentService(
pipelines={
"sql_executor": sql_executor.SQLExecutor(
"sql_executor": retrieval.SQLExecutor(
**pipe_components["sql_executor"],
),
"chart_adjustment": chart_adjustment.ChartAdjustment(
"chart_adjustment": generation.ChartAdjustment(
**pipe_components["chart_adjustment"],
),
},
**query_cache,
),
sql_answer_service=SqlAnswerService(
pipelines={
"preprocess_sql_data": preprocess_sql_data.PreprocessSqlData(
"preprocess_sql_data": retrieval.PreprocessSqlData(
**pipe_components["preprocess_sql_data"],
),
"sql_answer": sql_answer.SQLAnswer(
"sql_answer": generation.SQLAnswer(
**pipe_components["sql_answer"],
),
},
**query_cache,
),
ask_details_service=AskDetailsService(
pipelines={
"sql_breakdown": sql_breakdown.SQLBreakdown(
"sql_breakdown": generation.SQLBreakdown(
**pipe_components["sql_breakdown"],
),
"sql_summary": sql_summary.SQLSummary(
"sql_summary": generation.SQLSummary(
**pipe_components["sql_summary"],
),
},
Expand All @@ -184,45 +160,45 @@ def create_service_container(
table_retrieval_size=settings.table_retrieval_size,
table_column_retrieval_size=settings.table_column_retrieval_size,
),
"sql_expansion": sql_expansion.SQLExpansion(
"sql_expansion": generation.SQLExpansion(
**pipe_components["sql_expansion"],
),
"sql_correction": sql_correction.SQLCorrection(
"sql_correction": generation.SQLCorrection(
**pipe_components["sql_correction"],
),
"sql_summary": sql_summary.SQLSummary(
"sql_summary": generation.SQLSummary(
**pipe_components["sql_summary"],
),
},
**query_cache,
),
sql_explanation_service=SQLExplanationService(
pipelines={
"sql_explanation": sql_explanation.SQLExplanation(
"sql_explanation": generation.SQLExplanation(
**pipe_components["sql_explanation"],
)
},
**query_cache,
),
sql_regeneration_service=SQLRegenerationService(
pipelines={
"sql_regeneration": sql_regeneration.SQLRegeneration(
"sql_regeneration": generation.SQLRegeneration(
**pipe_components["sql_regeneration"],
)
},
**query_cache,
),
relationship_recommendation=RelationshipRecommendation(
pipelines={
"relationship_recommendation": relationship_recommendation.RelationshipRecommendation(
"relationship_recommendation": generation.RelationshipRecommendation(
**pipe_components["relationship_recommendation"],
)
},
**query_cache,
),
question_recommendation=QuestionRecommendation(
pipelines={
"question_recommendation": question_recommendation.QuestionRecommendation(
"question_recommendation": generation.QuestionRecommendation(
**pipe_components["question_recommendation"],
),
"retrieval": retrieval.Retrieval(
Expand All @@ -231,7 +207,7 @@ def create_service_container(
table_column_retrieval_size=settings.table_column_retrieval_size,
allow_using_db_schemas_without_pruning=settings.allow_using_db_schemas_without_pruning,
),
"sql_generation": sql_generation.SQLGeneration(
"sql_generation": generation.SQLGeneration(
**pipe_components["sql_generation"],
),
},
Expand Down
35 changes: 35 additions & 0 deletions wren-ai-service/src/pipelines/generation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from .chart_adjustment import ChartAdjustment
from .chart_generation import ChartGeneration
from .data_assistance import DataAssistance
from .followup_sql_generation import FollowUpSQLGeneration
from .intent_classification import IntentClassification
from .question_recommendation import QuestionRecommendation
from .relationship_recommendation import RelationshipRecommendation
from .semantics_enrichment import SemanticsEnrichment
from .sql_answer import SQLAnswer
from .sql_breakdown import SQLBreakdown
from .sql_correction import SQLCorrection
from .sql_expansion import SQLExpansion
from .sql_explanation import SQLExplanation
from .sql_generation import SQLGeneration
from .sql_regeneration import SQLRegeneration
from .sql_summary import SQLSummary

__all__ = [
"ChartAdjustment",
"ChartGeneration",
"DataAssistance",
"FollowUpSQLGeneration",
"IntentClassification",
"QuestionRecommendation",
"RelationshipRecommendation",
"SemanticsEnrichment",
"SQLAnswer",
"SQLBreakdown",
"SQLCorrection",
"SQLExpansion",
"SQLExplanation",
"SQLGeneration",
"SQLRegeneration",
"SQLSummary",
]
Loading
Loading