diff --git a/deployment/kustomizations/base/cm.yaml b/deployment/kustomizations/base/cm.yaml index 145fbcd64..7e59f8f39 100644 --- a/deployment/kustomizations/base/cm.yaml +++ b/deployment/kustomizations/base/cm.yaml @@ -141,7 +141,7 @@ data: - name: sql_regeneration llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - - name: semantics_description + - name: semantics_enrichment llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: relationship_recommendation llm: litellm_llm.gpt-4o-mini-2024-07-18 diff --git a/docker/config.example.yaml b/docker/config.example.yaml index 8f2f8a804..fa03c339e 100644 --- a/docker/config.example.yaml +++ b/docker/config.example.yaml @@ -91,7 +91,7 @@ pipes: - name: sql_regeneration llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - - name: semantics_description + - name: semantics_enrichment llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: relationship_recommendation llm: litellm_llm.gpt-4o-mini-2024-07-18 diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index 3c8f5c758..335f50d8d 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -6,38 +6,14 @@ from src.config import Settings from src.core.pipeline import PipelineComponent from src.core.provider import EmbedderProvider, LLMProvider -from src.pipelines import indexing -from src.pipelines.generation import ( - chart_adjustment, - chart_generation, - data_assistance, - followup_sql_generation, - intent_classification, - question_recommendation, - relationship_recommendation, - semantics_description, - sql_answer, - sql_breakdown, - sql_correction, - sql_expansion, - sql_explanation, - sql_generation, - sql_regeneration, - sql_summary, -) -from src.pipelines.retrieval import ( - historical_question, - preprocess_sql_data, - retrieval, - sql_executor, -) +from src.pipelines import generation, indexing, retrieval from src.web.v1.services.ask import AskService from src.web.v1.services.ask_details import AskDetailsService from src.web.v1.services.chart import ChartService from src.web.v1.services.chart_adjustment import ChartAdjustmentService from src.web.v1.services.question_recommendation import QuestionRecommendation from src.web.v1.services.relationship_recommendation import RelationshipRecommendation -from src.web.v1.services.semantics_description import SemanticsDescription +from src.web.v1.services.semantics_enrichment import SemanticsEnrichment from src.web.v1.services.semantics_preparation import SemanticsPreparationService from src.web.v1.services.sql_answer import SqlAnswerService from src.web.v1.services.sql_expansion import SqlExpansionService @@ -53,7 +29,7 @@ class ServiceContainer: ask_details_service: AskDetailsService question_recommendation: QuestionRecommendation relationship_recommendation: RelationshipRecommendation - semantics_description: SemanticsDescription + semantics_enrichment: SemanticsEnrichment semantics_preparation_service: SemanticsPreparationService chart_service: ChartService chart_adjustment_service: ChartAdjustmentService @@ -78,10 +54,10 @@ def create_service_container( "ttl": settings.query_cache_ttl, } return ServiceContainer( - semantics_description=SemanticsDescription( + semantics_enrichment=SemanticsEnrichment( pipelines={ - "semantics_description": semantics_description.SemanticsDescription( - **pipe_components["semantics_description"], + "semantics_enrichment": generation.SemanticsEnrichment( + **pipe_components["semantics_enrichment"], ) }, **query_cache, @@ -103,10 +79,10 @@ def create_service_container( ), ask_service=AskService( pipelines={ - "intent_classification": intent_classification.IntentClassification( + "intent_classification": generation.IntentClassification( **pipe_components["intent_classification"], ), - "data_assistance": data_assistance.DataAssistance( + "data_assistance": generation.DataAssistance( **pipe_components["data_assistance"] ), "retrieval": retrieval.Retrieval( @@ -115,19 +91,19 @@ def create_service_container( table_column_retrieval_size=settings.table_column_retrieval_size, allow_using_db_schemas_without_pruning=settings.allow_using_db_schemas_without_pruning, ), - "historical_question": historical_question.HistoricalQuestion( + "historical_question": retrieval.HistoricalQuestion( **pipe_components["historical_question_retrieval"], ), - "sql_generation": sql_generation.SQLGeneration( + "sql_generation": generation.SQLGeneration( **pipe_components["sql_generation"], ), - "sql_correction": sql_correction.SQLCorrection( + "sql_correction": generation.SQLCorrection( **pipe_components["sql_correction"], ), - "followup_sql_generation": followup_sql_generation.FollowUpSQLGeneration( + "followup_sql_generation": generation.FollowUpSQLGeneration( **pipe_components["followup_sql_generation"], ), - "sql_summary": sql_summary.SQLSummary( + "sql_summary": generation.SQLSummary( **pipe_components["sql_summary"], ), }, @@ -135,10 +111,10 @@ def create_service_container( ), chart_service=ChartService( pipelines={ - "sql_executor": sql_executor.SQLExecutor( + "sql_executor": retrieval.SQLExecutor( **pipe_components["sql_executor"], ), - "chart_generation": chart_generation.ChartGeneration( + "chart_generation": generation.ChartGeneration( **pipe_components["chart_generation"], ), }, @@ -146,10 +122,10 @@ def create_service_container( ), chart_adjustment_service=ChartAdjustmentService( pipelines={ - "sql_executor": sql_executor.SQLExecutor( + "sql_executor": retrieval.SQLExecutor( **pipe_components["sql_executor"], ), - "chart_adjustment": chart_adjustment.ChartAdjustment( + "chart_adjustment": generation.ChartAdjustment( **pipe_components["chart_adjustment"], ), }, @@ -157,10 +133,10 @@ def create_service_container( ), sql_answer_service=SqlAnswerService( pipelines={ - "preprocess_sql_data": preprocess_sql_data.PreprocessSqlData( + "preprocess_sql_data": retrieval.PreprocessSqlData( **pipe_components["preprocess_sql_data"], ), - "sql_answer": sql_answer.SQLAnswer( + "sql_answer": generation.SQLAnswer( **pipe_components["sql_answer"], ), }, @@ -168,10 +144,10 @@ def create_service_container( ), ask_details_service=AskDetailsService( pipelines={ - "sql_breakdown": sql_breakdown.SQLBreakdown( + "sql_breakdown": generation.SQLBreakdown( **pipe_components["sql_breakdown"], ), - "sql_summary": sql_summary.SQLSummary( + "sql_summary": generation.SQLSummary( **pipe_components["sql_summary"], ), }, @@ -184,13 +160,13 @@ def create_service_container( table_retrieval_size=settings.table_retrieval_size, table_column_retrieval_size=settings.table_column_retrieval_size, ), - "sql_expansion": sql_expansion.SQLExpansion( + "sql_expansion": generation.SQLExpansion( **pipe_components["sql_expansion"], ), - "sql_correction": sql_correction.SQLCorrection( + "sql_correction": generation.SQLCorrection( **pipe_components["sql_correction"], ), - "sql_summary": sql_summary.SQLSummary( + "sql_summary": generation.SQLSummary( **pipe_components["sql_summary"], ), }, @@ -198,7 +174,7 @@ def create_service_container( ), sql_explanation_service=SQLExplanationService( pipelines={ - "sql_explanation": sql_explanation.SQLExplanation( + "sql_explanation": generation.SQLExplanation( **pipe_components["sql_explanation"], ) }, @@ -206,7 +182,7 @@ def create_service_container( ), sql_regeneration_service=SQLRegenerationService( pipelines={ - "sql_regeneration": sql_regeneration.SQLRegeneration( + "sql_regeneration": generation.SQLRegeneration( **pipe_components["sql_regeneration"], ) }, @@ -214,7 +190,7 @@ def create_service_container( ), relationship_recommendation=RelationshipRecommendation( pipelines={ - "relationship_recommendation": relationship_recommendation.RelationshipRecommendation( + "relationship_recommendation": generation.RelationshipRecommendation( **pipe_components["relationship_recommendation"], ) }, @@ -222,7 +198,7 @@ def create_service_container( ), question_recommendation=QuestionRecommendation( pipelines={ - "question_recommendation": question_recommendation.QuestionRecommendation( + "question_recommendation": generation.QuestionRecommendation( **pipe_components["question_recommendation"], ), "retrieval": retrieval.Retrieval( @@ -231,7 +207,7 @@ def create_service_container( table_column_retrieval_size=settings.table_column_retrieval_size, allow_using_db_schemas_without_pruning=settings.allow_using_db_schemas_without_pruning, ), - "sql_generation": sql_generation.SQLGeneration( + "sql_generation": generation.SQLGeneration( **pipe_components["sql_generation"], ), }, diff --git a/wren-ai-service/src/pipelines/generation/__init__.py b/wren-ai-service/src/pipelines/generation/__init__.py index e69de29bb..4a10ac239 100644 --- a/wren-ai-service/src/pipelines/generation/__init__.py +++ b/wren-ai-service/src/pipelines/generation/__init__.py @@ -0,0 +1,35 @@ +from .chart_adjustment import ChartAdjustment +from .chart_generation import ChartGeneration +from .data_assistance import DataAssistance +from .followup_sql_generation import FollowUpSQLGeneration +from .intent_classification import IntentClassification +from .question_recommendation import QuestionRecommendation +from .relationship_recommendation import RelationshipRecommendation +from .semantics_enrichment import SemanticsEnrichment +from .sql_answer import SQLAnswer +from .sql_breakdown import SQLBreakdown +from .sql_correction import SQLCorrection +from .sql_expansion import SQLExpansion +from .sql_explanation import SQLExplanation +from .sql_generation import SQLGeneration +from .sql_regeneration import SQLRegeneration +from .sql_summary import SQLSummary + +__all__ = [ + "ChartAdjustment", + "ChartGeneration", + "DataAssistance", + "FollowUpSQLGeneration", + "IntentClassification", + "QuestionRecommendation", + "RelationshipRecommendation", + "SemanticsEnrichment", + "SQLAnswer", + "SQLBreakdown", + "SQLCorrection", + "SQLExpansion", + "SQLExplanation", + "SQLGeneration", + "SQLRegeneration", + "SQLSummary", +] diff --git a/wren-ai-service/src/pipelines/generation/semantics_description.py b/wren-ai-service/src/pipelines/generation/semantics_enrichment.py similarity index 59% rename from wren-ai-service/src/pipelines/generation/semantics_description.py rename to wren-ai-service/src/pipelines/generation/semantics_enrichment.py index d85826f98..d824ae123 100644 --- a/wren-ai-service/src/pipelines/generation/semantics_description.py +++ b/wren-ai-service/src/pipelines/generation/semantics_enrichment.py @@ -17,7 +17,7 @@ ## Start of Pipeline @observe(capture_input=False) -def picked_models(mdl: dict, selected_models: list[str]) -> list[dict]: +def picked_models(mdl: dict) -> list[dict]: def relation_filter(column: dict) -> bool: return "relationship" not in column @@ -27,6 +27,7 @@ def column_formatter(columns: list[dict]) -> list[dict]: "name": column["name"], "type": column["type"], "properties": { + "alias": column["properties"].get("displayName", ""), "description": column["properties"].get("description", ""), }, } @@ -35,19 +36,17 @@ def column_formatter(columns: list[dict]) -> list[dict]: ] def extract(model: dict) -> dict: + prop = model["properties"] return { "name": model["name"], "columns": column_formatter(model["columns"]), "properties": { - "description": model["properties"].get("description", ""), + "alias": prop.get("displayName", ""), + "description": prop.get("description", ""), }, } - return [ - extract(model) - for model in mdl.get("models", []) - if model.get("name", "") in selected_models - ] + return [extract(model) for model in mdl.get("models", [])] @observe(capture_input=False) @@ -90,6 +89,7 @@ def wrapper(text: str) -> str: ## End of Pipeline class ModelProperties(BaseModel): + alias: str description: str @@ -108,76 +108,65 @@ class SemanticResult(BaseModel): models: list[SemanticModel] -SEMANTICS_DESCRIPTION_MODEL_KWARGS = { +semantics_enrichment_KWARGS = { "response_format": { "type": "json_schema", "json_schema": { - "name": "semantic_description", + "name": "semantics_enrichment", "schema": SemanticResult.model_json_schema(), }, } } system_prompt = """ -I have a data model represented in JSON format, with the following structure: - -``` -[ - {'name': 'model', 'columns': [ - {'name': 'column_1', 'type': 'type', 'properties': {} - }, - {'name': 'column_2', 'type': 'type', 'properties': {} - }, - {'name': 'column_3', 'type': 'type', 'properties': {} - } - ], 'properties': {} - } -] -``` - -Your task is to update this JSON structure by adding a `description` field inside both the `properties` attribute of each `column` and the `model` itself. -Each `description` should be derived from a user-provided input that explains the purpose or context of the `model` and its respective columns. -Follow these steps: -1. **For the `model`**: Prompt the user to provide a brief description of the model's overall purpose or its context. Insert this description in the `properties` field of the `model`. -2. **For each `column`**: Ask the user to describe each column's role or significance. Each column's description should be added under its respective `properties` field in the format: `'description': 'user-provided text'`. -3. Ensure that the output is a well-formatted JSON structure, preserving the input's original format and adding the appropriate `description` fields. - -### Output Format: - -``` +You are a data model expert. Your task is to enrich a JSON data model with descriptive metadata. + +Input Format: +[{ + 'name': 'model', + 'columns': [{'name': 'column', 'type': 'type', 'properties': {'alias': 'alias', 'description': 'description'}}], + 'properties': {'alias': 'alias', 'description': 'description'} +}] + +For each model and column, you will: +1. Add a clear, concise alias that serves as a business-friendly name +2. Add a detailed description explaining its purpose and usage + +Guidelines: +- Descriptions should be clear, concise and business-focused +- Aliases should be intuitive and user-friendly +- Use the user's context to inform the descriptions +- Maintain technical accuracy while being accessible to non-technical users + +Output Format: { - "models": [ - { + "models": [{ "name": "model", - "columns": [ - { - "name": "column_1", - "properties": { - "description": "" - } - }, - { - "name": "column_2", - "properties": { - "description": "" - } - }, - { - "name": "column_3", - "properties": { - "description": "" - } + "columns": [{ + "name": "column", + "properties": { + "alias": "User-friendly column name", + "description": "Clear explanation of column purpose" } - ], + }], "properties": { - "description": "" - } + "alias": "User-friendly model name", + "description": "Clear explanation of model purpose" } - ] + }] +} + +Example: +Input model "orders" with column "created_at" might become: +{ + "name": "created_at", + "properties": { + "alias": "Order Creation Date", + "description": "Timestamp when the order was first created in the system" + } } -``` -Make sure that the descriptions are concise, informative, and contextually appropriate based on the input provided by the user. +Focus on providing business value through clear, accurate descriptions while maintaining JSON structure integrity. """ user_prompt_template = """ @@ -186,17 +175,17 @@ class SemanticResult(BaseModel): Picked models: {{ picked_models }} Localization Language: {{ language }} -Please provide a brief description for the model and each column based on the user's prompt. +Please provide a brief description and alias for the model and each column based on the user's prompt. """ -class SemanticsDescription(BasicPipeline): +class SemanticsEnrichment(BasicPipeline): def __init__(self, llm_provider: LLMProvider, **_): self._components = { "prompt_builder": PromptBuilder(template=user_prompt_template), "generator": llm_provider.get_generator( system_prompt=system_prompt, - generation_kwargs=SEMANTICS_DESCRIPTION_MODEL_KWARGS, + generation_kwargs=semantics_enrichment_KWARGS, ), } self._final = "normalize" @@ -209,16 +198,13 @@ def __init__(self, llm_provider: LLMProvider, **_): async def run( self, user_prompt: str, - selected_models: list[str], mdl: dict, language: str = "en", ) -> dict: - logger.info("Semantics Description Generation pipeline is running...") return await self._pipe.execute( [self._final], inputs={ "user_prompt": user_prompt, - "selected_models": selected_models, "mdl": mdl, "language": language, **self._components, @@ -230,10 +216,9 @@ async def run( from src.pipelines.common import dry_run_pipeline dry_run_pipeline( - SemanticsDescription, - "semantics_description", + SemanticsEnrichment, + "semantics_enrichment", user_prompt="Track student enrollments, grades, and GPA calculations to monitor academic performance and identify areas for student support", - selected_models=[], mdl={}, language="en", ) diff --git a/wren-ai-service/src/pipelines/retrieval/__init__.py b/wren-ai-service/src/pipelines/retrieval/__init__.py index e69de29bb..65a653730 100644 --- a/wren-ai-service/src/pipelines/retrieval/__init__.py +++ b/wren-ai-service/src/pipelines/retrieval/__init__.py @@ -0,0 +1,6 @@ +from .historical_question import HistoricalQuestion +from .preprocess_sql_data import PreprocessSqlData +from .retrieval import Retrieval +from .sql_executor import SQLExecutor + +__all__ = ["HistoricalQuestion", "PreprocessSqlData", "Retrieval", "SQLExecutor"] diff --git a/wren-ai-service/src/web/v1/routers/__init__.py b/wren-ai-service/src/web/v1/routers/__init__.py index d3ebd402e..542dbcb63 100644 --- a/wren-ai-service/src/web/v1/routers/__init__.py +++ b/wren-ai-service/src/web/v1/routers/__init__.py @@ -7,7 +7,7 @@ chart_adjustment, question_recommendation, relationship_recommendation, - semantics_description, + semantics_enrichment, semantics_preparations, sql_answers, sql_expansions, @@ -20,7 +20,7 @@ router.include_router(ask_details.router) router.include_router(question_recommendation.router) router.include_router(relationship_recommendation.router) -router.include_router(semantics_description.router) +router.include_router(semantics_enrichment.router) router.include_router(semantics_preparations.router) router.include_router(sql_answers.router) router.include_router(sql_expansions.router) diff --git a/wren-ai-service/src/web/v1/routers/semantics_description.py b/wren-ai-service/src/web/v1/routers/semantics_enrichment.py similarity index 71% rename from wren-ai-service/src/web/v1/routers/semantics_description.py rename to wren-ai-service/src/web/v1/routers/semantics_enrichment.py index b8d47bcbf..0068187a4 100644 --- a/wren-ai-service/src/web/v1/routers/semantics_description.py +++ b/wren-ai-service/src/web/v1/routers/semantics_enrichment.py @@ -12,18 +12,18 @@ get_service_metadata, ) from src.web.v1.services import Configuration -from src.web.v1.services.semantics_description import SemanticsDescription +from src.web.v1.services.semantics_enrichment import SemanticsEnrichment router = APIRouter() """ -Semantics Description Router +Semantics Enrichment Router -This router handles endpoints related to generating and retrieving semantic descriptions. +This router handles endpoints related to generating and retrieving semantics enrichment for data models. Endpoints: -1. POST /semantics-descriptions - - Generates a new semantic description +1. POST /semantics-enrichment + - Generates a new semantics enrichment task for data models - Request body: PostRequest { "selected_models": ["model1", "model2"], # List of model names to describe @@ -31,7 +31,7 @@ "mdl": "{ ... }", # JSON string of the MDL (Model Definition Language) "project_id": "project-id", # Optional project ID "configuration": { # Optional configuration settings - "language": "English" # Optional language, defaults to "English" + "language": "en" # Optional language, defaults to "en" } } - Response: PostResponse @@ -39,42 +39,36 @@ "id": "unique-uuid" # Unique identifier for the generated description } -2. GET /semantics-descriptions/{id} - - Retrieves the status and result of a semantic description generation +2. GET /semantics-enrichment/{id} + - Retrieves the status and result of a semantics enrichment generation - Path parameter: id (str) - Response: GetResponse { "id": "unique-uuid", # Unique identifier of the description "status": "generating" | "finished" | "failed", - "response": [ # Present only if status is "finished" or "generating" - { - "name": "model1", - "columns": [ - { - "name": "col1", - "description": "Unique identifier for each record in the example model." - } - ], - "description": "This model is used for analysis purposes, capturing key attributes of records." - }, - { - "name": "model2", - "columns": [ - { - "name": "col1", - "description": "Unique identifier for each record in the example model." - } - ], - "description": "This model is used for analysis purposes, capturing key attributes of records." - } - ], + "response": { # Present only if status is "finished" or "generating" + "models": [ + { + "name": "model1", + "columns": [ + { + "name": "col1", + "displayName": "col1_alias", + "description": "Unique identifier for each record in the example model." + } + ], + "displayName": "model1_alias", + "description": "This model is used for analysis purposes, capturing key attributes of records." + } + ] + }, "error": { # Present only if status is "failed" "code": "OTHERS", "message": "Error description" } } -The semantic description generation is an asynchronous process. The POST endpoint +The semantics enrichment generation is an asynchronous process. The POST endpoint initiates the generation and returns immediately with an ID. The GET endpoint can then be used to check the status and retrieve the result when it's ready. @@ -98,9 +92,14 @@ class PostResponse(BaseModel): id: str +@router.post( + "/semantics-enrichment", + response_model=PostResponse, +) @router.post( "/semantics-descriptions", response_model=PostResponse, + deprecated=True, ) async def generate( request: PostRequest, @@ -109,10 +108,10 @@ async def generate( service_metadata: ServiceMetadata = Depends(get_service_metadata), ) -> PostResponse: id = str(uuid.uuid4()) - service = service_container.semantics_description + service = service_container.semantics_enrichment - service[id] = SemanticsDescription.Resource(id=id) - input = SemanticsDescription.Input( + service[id] = SemanticsEnrichment.Resource(id=id) + input = SemanticsEnrichment.Input( id=id, selected_models=request.selected_models, user_prompt=request.user_prompt, @@ -134,15 +133,20 @@ class GetResponse(BaseModel): error: Optional[dict] +@router.get( + "/semantics-enrichment/{id}", + response_model=GetResponse, +) @router.get( "/semantics-descriptions/{id}", response_model=GetResponse, + deprecated=True, ) async def get( id: str, service_container: ServiceContainer = Depends(get_service_container), ) -> GetResponse: - resource = service_container.semantics_description[id] + resource = service_container.semantics_enrichment[id] def _formatter(response: Optional[dict]) -> Optional[list[dict]]: if response is None: @@ -154,10 +158,12 @@ def _formatter(response: Optional[dict]) -> Optional[list[dict]]: "columns": [ { "name": column["name"], + "displayName": column["properties"].get("alias", ""), "description": column["properties"].get("description", ""), } for column in model_data["columns"] ], + "displayName": model_data["properties"].get("alias", ""), "description": model_data["properties"].get("description", ""), } for model_name, model_data in response.items() diff --git a/wren-ai-service/src/web/v1/services/semantics_description.py b/wren-ai-service/src/web/v1/services/semantics_enrichment.py similarity index 81% rename from wren-ai-service/src/web/v1/services/semantics_description.py rename to wren-ai-service/src/web/v1/services/semantics_enrichment.py index 040f333dc..8ad0e1af1 100644 --- a/wren-ai-service/src/web/v1/services/semantics_description.py +++ b/wren-ai-service/src/web/v1/services/semantics_enrichment.py @@ -14,7 +14,7 @@ logger = logging.getLogger("wren-ai-service") -class SemanticsDescription: +class SemanticsEnrichment: class Input(BaseModel): id: str selected_models: list[str] @@ -40,7 +40,7 @@ def __init__( ttl: int = 120, ): self._pipelines = pipelines - self._cache: Dict[str, SemanticsDescription.Resource] = TTLCache( + self._cache: Dict[str, SemanticsEnrichment.Resource] = TTLCache( maxsize=maxsize, ttl=ttl ) @@ -55,7 +55,7 @@ def _handle_exception( status="failed", error=self.Resource.Error(code=code, message=error_message), ) - logger.error(error_message) + logger.error(f"Project ID: {request.project_id}, {error_message}") def _chunking( self, mdl_dict: dict, request: Input, chunk_size: int = 50 @@ -65,27 +65,23 @@ def _chunking( "language": request.configuration.language, } + def _model_picker(model: dict, selected: list[str]) -> bool: + return model["name"] in selected or "*" in selected + chunks = [ { **model, "columns": model["columns"][i : i + chunk_size], } for model in mdl_dict["models"] - if model["name"] in request.selected_models + if _model_picker(model, request.selected_models) for i in range(0, len(model["columns"]), chunk_size) ] - return [ - { - **template, - "mdl": {"models": [chunk]}, - "selected_models": [chunk["name"]], - } - for chunk in chunks - ] + return [{**template, "mdl": {"models": [chunk]}} for chunk in chunks] async def _generate_task(self, request_id: str, chunk: dict): - resp = await self._pipelines["semantics_description"].run(**chunk) + resp = await self._pipelines["semantics_enrichment"].run(**chunk) normalize = resp.get("normalize") current = self[request_id] @@ -98,10 +94,12 @@ async def _generate_task(self, request_id: str, chunk: dict): current.response[key]["columns"].extend(normalize[key]["columns"]) - @observe(name="Generate Semantics Description") + @observe(name="Enrich Semantics") @trace_metadata async def generate(self, request: Input, **kwargs) -> Resource: - logger.info("Generate Semantics Description pipeline is running...") + logger.info( + f"Project ID: {request.project_id}, Enrich Semantics pipeline is running..." + ) try: mdl_dict = orjson.loads(request.mdl) @@ -121,7 +119,7 @@ async def generate(self, request: Input, **kwargs) -> Resource: except Exception as e: self._handle_exception( request, - f"An error occurred during semantics description generation: {str(e)}", + f"An error occurred during semantics enrichment: {str(e)}", ) return self[request.id].with_metadata() @@ -130,7 +128,7 @@ def __getitem__(self, id: str) -> Resource: response = self._cache.get(id) if response is None: - message = f"Semantics Description Resource with ID '{id}' not found." + message = f"Semantics Enrichment Resource with ID '{id}' not found." logger.exception(message) return self.Resource( id=id, diff --git a/wren-ai-service/tests/data/config.test.yaml b/wren-ai-service/tests/data/config.test.yaml index 006f4d15c..e583bb65b 100644 --- a/wren-ai-service/tests/data/config.test.yaml +++ b/wren-ai-service/tests/data/config.test.yaml @@ -70,7 +70,7 @@ pipes: - name: sql_regeneration llm: openai_llm.gpt-4o-mini engine: wren_ui - - name: semantics_description + - name: semantics_enrichment llm: openai_llm.gpt-4o-mini - name: relationship_recommendation llm: openai_llm.gpt-4o-mini diff --git a/wren-ai-service/tests/pytest/services/test_semantics_description.py b/wren-ai-service/tests/pytest/services/test_semantics_enrichment.py similarity index 57% rename from wren-ai-service/tests/pytest/services/test_semantics_description.py rename to wren-ai-service/tests/pytest/services/test_semantics_enrichment.py index dc48e7339..2d9c31f40 100644 --- a/wren-ai-service/tests/pytest/services/test_semantics_description.py +++ b/wren-ai-service/tests/pytest/services/test_semantics_enrichment.py @@ -4,7 +4,7 @@ import orjson import pytest -from src.web.v1.services.semantics_description import SemanticsDescription +from src.web.v1.services.semantics_enrichment import SemanticsEnrichment @pytest.fixture @@ -13,22 +13,35 @@ def service(): mock_pipeline.run.return_value = { "normalize": { "model1": { - "columns": [], - "properties": {"description": "Test description"}, + "columns": [ + { + "name": "column1", + "type": "varchar", + "notNull": False, + "properties": { + "description": "Test description", + "alias": "column1_alias", + }, + } + ], + "properties": { + "description": "Test description", + "alias": "model1_alias", + }, } } } - pipelines = {"semantics_description": mock_pipeline} - return SemanticsDescription(pipelines=pipelines) + pipelines = {"semantics_enrichment": mock_pipeline} + return SemanticsEnrichment(pipelines=pipelines) @pytest.mark.asyncio -async def test_generate_semantics_description( - service: SemanticsDescription, +async def test_generate_semantics_enrichment( + service: SemanticsEnrichment, ): - service["test_id"] = SemanticsDescription.Resource(id="test_id") - request = SemanticsDescription.Input( + service["test_id"] = SemanticsEnrichment.Resource(id="test_id") + request = SemanticsEnrichment.Input( id="test_id", user_prompt="Describe the model", selected_models=["model1"], @@ -42,19 +55,32 @@ async def test_generate_semantics_description( assert response.status == "finished" assert response.response == { "model1": { - "columns": [], - "properties": {"description": "Test description"}, + "columns": [ + { + "name": "column1", + "type": "varchar", + "notNull": False, + "properties": { + "description": "Test description", + "alias": "column1_alias", + }, + } + ], + "properties": { + "description": "Test description", + "alias": "model1_alias", + }, } } assert response.error is None @pytest.mark.asyncio -async def test_generate_semantics_description_with_invalid_mdl( - service: SemanticsDescription, +async def test_generate_semantics_enrichment_with_invalid_mdl( + service: SemanticsEnrichment, ): - service["test_id"] = SemanticsDescription.Resource(id="test_id") - request = SemanticsDescription.Input( + service["test_id"] = SemanticsEnrichment.Resource(id="test_id") + request = SemanticsEnrichment.Input( id="test_id", user_prompt="Describe the model", selected_models=["model1"], @@ -72,18 +98,18 @@ async def test_generate_semantics_description_with_invalid_mdl( @pytest.mark.asyncio -async def test_generate_semantics_description_with_exception( - service: SemanticsDescription, +async def test_generate_semantics_enrichment_with_exception( + service: SemanticsEnrichment, ): - service["test_id"] = SemanticsDescription.Resource(id="test_id") - request = SemanticsDescription.Input( + service["test_id"] = SemanticsEnrichment.Resource(id="test_id") + request = SemanticsEnrichment.Input( id="test_id", user_prompt="Describe the model", selected_models=["model1"], mdl='{"models": [{"name": "model1", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}]}', ) - service._pipelines["semantics_description"].run.side_effect = Exception( + service._pipelines["semantics_enrichment"].run.side_effect = Exception( "Test exception" ) @@ -94,16 +120,13 @@ async def test_generate_semantics_description_with_exception( assert response.status == "failed" assert response.response is None assert response.error.code == "OTHERS" - assert ( - "An error occurred during semantics description generation" - in response.error.message - ) + assert "An error occurred during semantics enrichment:" in response.error.message -def test_get_semantics_description_result( - service: SemanticsDescription, +def test_get_semantics_enrichment_result( + service: SemanticsEnrichment, ): - expected_response = SemanticsDescription.Resource( + expected_response = SemanticsEnrichment.Resource( id="test_id", status="finished", response={"model1": {"description": "Test description"}}, @@ -115,8 +138,8 @@ def test_get_semantics_description_result( assert result == expected_response -def test_get_non_existent_semantics_description_result( - service: SemanticsDescription, +def test_get_non_existent_semantics_enrichment_result( + service: SemanticsEnrichment, ): result = service["non_existent_id"] @@ -129,10 +152,10 @@ def test_get_non_existent_semantics_description_result( @pytest.mark.asyncio async def test_batch_processing_with_multiple_models( - service: SemanticsDescription, + service: SemanticsEnrichment, ): - service["test_id"] = SemanticsDescription.Resource(id="test_id") - request = SemanticsDescription.Input( + service["test_id"] = SemanticsEnrichment.Resource(id="test_id") + request = SemanticsEnrichment.Input( id="test_id", user_prompt="Describe the models", selected_models=["model1", "model2", "model3"], @@ -140,7 +163,7 @@ async def test_batch_processing_with_multiple_models( ) # Mock pipeline responses for each chunk - service._pipelines["semantics_description"].run.side_effect = [ + service._pipelines["semantics_enrichment"].run.side_effect = [ {"normalize": {"model1": {"description": "Description 1"}}}, {"normalize": {"model2": {"description": "Description 2"}}}, {"normalize": {"model3": {"description": "Description 3"}}}, @@ -161,45 +184,79 @@ async def test_batch_processing_with_multiple_models( assert len(chunks) == 3 # Default chunk_size=1 assert all("user_prompt" in chunk for chunk in chunks) assert all("mdl" in chunk for chunk in chunks) - assert [len(chunk["selected_models"]) for chunk in chunks] == [1, 1, 1] def test_batch_processing_with_custom_chunk_size( - service: SemanticsDescription, + service: SemanticsEnrichment, ): - service["test_id"] = SemanticsDescription.Resource(id="test_id") - request = SemanticsDescription.Input( + test_mdl = { + "models": [ + { + "name": "model1", + "columns": [{"name": "column1", "type": "varchar", "notNull": False}], + }, + { + "name": "model2", + "columns": [{"name": "column1", "type": "varchar", "notNull": False}], + }, + { + "name": "model3", + "columns": [{"name": "column1", "type": "varchar", "notNull": False}], + }, + { + "name": "model4", + "columns": [ + {"name": "column1", "type": "varchar", "notNull": False}, + {"name": "column2", "type": "varchar", "notNull": False}, + ], + }, + ] + } + service["test_id"] = SemanticsEnrichment.Resource(id="test_id") + request = SemanticsEnrichment.Input( id="test_id", user_prompt="Describe the models", selected_models=["model1", "model2", "model3", "model4"], - mdl='{"models": [{"name": "model1", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model2", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model3", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model4", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}]}', + mdl=orjson.dumps(test_mdl), ) # Test chunking with custom chunk size - chunks = service._chunking(orjson.loads(request.mdl), request, chunk_size=2) + chunks = service._chunking(orjson.loads(request.mdl), request, chunk_size=1) - assert len(chunks) == 4 - assert [len(chunk["selected_models"]) for chunk in chunks] == [1, 1, 1, 1] - assert chunks[0]["selected_models"] == ["model1"] - assert chunks[1]["selected_models"] == ["model2"] - assert chunks[2]["selected_models"] == ["model3"] - assert chunks[3]["selected_models"] == ["model4"] + assert len(chunks) == 5 + assert chunks[0]["mdl"]["models"][0]["name"] == "model1" + assert chunks[1]["mdl"]["models"][0]["name"] == "model2" + assert chunks[2]["mdl"]["models"][0]["name"] == "model3" + assert chunks[3]["mdl"]["models"][0]["name"] == "model4" + assert chunks[4]["mdl"]["models"][0]["name"] == "model4" @pytest.mark.asyncio async def test_batch_processing_partial_failure( - service: SemanticsDescription, + service: SemanticsEnrichment, ): - service["test_id"] = SemanticsDescription.Resource(id="test_id") - request = SemanticsDescription.Input( + test_mdl = { + "models": [ + { + "name": "model1", + "columns": [{"name": "column1", "type": "varchar", "notNull": False}], + }, + { + "name": "model2", + "columns": [{"name": "column1", "type": "varchar", "notNull": False}], + }, + ] + } + service["test_id"] = SemanticsEnrichment.Resource(id="test_id") + request = SemanticsEnrichment.Input( id="test_id", user_prompt="Describe the models", selected_models=["model1", "model2"], - mdl='{"models": [{"name": "model1", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model2", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}]}', + mdl=orjson.dumps(test_mdl), ) # Mock first chunk succeeds, second chunk fails - service._pipelines["semantics_description"].run.side_effect = [ + service._pipelines["semantics_enrichment"].run.side_effect = [ {"normalize": {"model1": {"description": "Description 1"}}}, Exception("Failed processing model2"), ] @@ -215,12 +272,12 @@ async def test_batch_processing_partial_failure( @pytest.mark.asyncio async def test_concurrent_updates_no_race_condition( - service: SemanticsDescription, + service: SemanticsEnrichment, ): test_id = "concurrent_test" - service[test_id] = SemanticsDescription.Resource(id=test_id) + service[test_id] = SemanticsEnrichment.Resource(id=test_id) - request = SemanticsDescription.Input( + request = SemanticsEnrichment.Input( id=test_id, user_prompt="Test concurrent updates", selected_models=["model1", "model2", "model3", "model4", "model5"], @@ -236,7 +293,7 @@ async def delayed_response(model_num, delay=0.1): } } - service._pipelines["semantics_description"].run.side_effect = [ + service._pipelines["semantics_enrichment"].run.side_effect = [ await delayed_response(1), await delayed_response(2), await delayed_response(3), diff --git a/wren-ai-service/tools/config/config.example.yaml b/wren-ai-service/tools/config/config.example.yaml index 2ba4c8ecc..4eb94500d 100644 --- a/wren-ai-service/tools/config/config.example.yaml +++ b/wren-ai-service/tools/config/config.example.yaml @@ -105,7 +105,7 @@ pipes: - name: sql_regeneration llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - - name: semantics_description + - name: semantics_enrichment llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: relationship_recommendation llm: litellm_llm.gpt-4o-mini-2024-07-18 diff --git a/wren-ai-service/tools/config/config.full.yaml b/wren-ai-service/tools/config/config.full.yaml index 742cafaaf..28e823de6 100644 --- a/wren-ai-service/tools/config/config.full.yaml +++ b/wren-ai-service/tools/config/config.full.yaml @@ -124,7 +124,7 @@ pipes: - name: sql_regeneration llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - - name: semantics_description + - name: semantics_enrichment llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: relationship_recommendation llm: litellm_llm.gpt-4o-mini-2024-07-18