
chore(wren-ai-service): Try to make evaluation work again #1085

Open · wants to merge 5 commits into main

Conversation

tedyyan (Contributor) commented Jan 5, 2025

After the recent code refactor, the evaluation no longer works.

This PR tries to make the evaluation work again.

However, it now reports the error below. It looks like the engine doesn't load the Spider test data. I followed the instructions at https://github.com/Canner/WrenAI/tree/main/wren-ai-service/eval#eval-dataset-preparationif-using-spider-10-dataset and put the data in wren-ai-service/tools/dev/spider1.0:

{
    "valid_generation_results": [],
    "invalid_generation_results": [
        {
            "sql": "SELECT \"c\".\"Continent\", COUNT(\"cm\".\"Id\") AS \"MakerCount\" FROM \"continents\" AS \"c\" LEFT JOIN \"countries\" AS \"co\" ON \"c\".\"ContId\" = \"co\".\"Continent\" LEFT JOIN \"car_makers\" AS \"cm\" ON LOWER(\"co\".\"CountryName\") = LOWER(\"cm\".\"Country\") GROUP BY \"c\".\"Continent\"",
            "type": "DRY_RUN",
            "error": "Cannot read properties of null (reading 'id')",
            "correlation_id": null
        }
    ]
}
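As a quick sanity check on reports of this shape, a minimal sketch (the report is embedded inline here for illustration) that counts outcomes and collects the distinct error messages for triage:

```python
import json

report = json.loads("""
{
    "valid_generation_results": [],
    "invalid_generation_results": [
        {
            "sql": "SELECT 1",
            "type": "DRY_RUN",
            "error": "Cannot read properties of null (reading 'id')",
            "correlation_id": null
        }
    ]
}
""")

# Count outcomes and gather distinct error messages.
valid = len(report["valid_generation_results"])
invalid = len(report["invalid_generation_results"])
errors = {r["error"] for r in report["invalid_generation_results"]}
```

With the report above, this yields zero valid results and one invalid DRY_RUN failure.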

Summary by CodeRabbit

  • New Features

    • Enhanced pipeline configuration management
    • Improved service metadata retrieval
    • Streamlined prediction process
    • Added new methods for handling queries in the AskPipeline
  • Refactor

    • Restructured AskPipeline class
    • Updated provider and component initialization
    • Modified prediction logic to focus on first dataset item
    • Adjusted SQL execution parameter handling
  • Bug Fixes

    • Corrected handling of project_id type in SQL execution
  • Documentation

    • Improved code organization and clarity

coderabbitai bot (Contributor) commented Jan 5, 2025

Warning

Rate limit exceeded

@tedyyan has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 5 minutes and 49 seconds before requesting another review.


📥 Commits

Reviewing files that changed from the base of the PR and between 2c7ab2c and 92877b2.

📒 Files selected for processing (3)
  • wren-ai-service/eval/pipelines.py (7 hunks)
  • wren-ai-service/eval/prediction.py (4 hunks)
  • wren-ai-service/src/globals.py (1 hunks)

Walkthrough

The pull request introduces substantial modifications to the AI service's evaluation and prediction pipelines. Key changes include the addition of the AskPipeline class with new methods for indexing and asking, a restructured constructor, and updates to configuration management. The prediction logic has been streamlined to focus on individual dataset items, while the ServiceMetadata class now includes a method for attribute retrieval. Overall, these changes enhance the functionality and organization of the service's components.

Changes

File: wren-ai-service/eval/pipelines.py
  • Added new methods: indexing_service, ask_service, dict_to_string
  • Restructured the AskPipeline constructor to accept new parameters
  • Modified the _process method for the new service components
  • Updated the init function to support new parameters
File: wren-ai-service/eval/prediction.py
  • Updated import statements
  • Replaced environment variable access with settings
  • Modified provider initialization and the return variable
  • Changed prediction logic to process only the first dataset item
  • Added a generate_components method
File: wren-ai-service/src/globals.py
  • Added a get method to the ServiceMetadata class for attribute retrieval
File: wren-ai-service/src/pipelines/common.py
  • Updated the project_id parameter type in the _classify_invalid_generation_results method
  • Removed the List[List[str]] type from the run method signature

Sequence Diagram

sequenceDiagram
    participant User
    participant AskPipeline
    participant IndexingService
    participant AskService

    User->>AskPipeline: Initialize with metadata
    AskPipeline->>IndexingService: Prepare semantics
    IndexingService-->>AskPipeline: Semantics ready
    AskPipeline->>AskService: Create AskRequest
    AskService-->>AskPipeline: Process query
    AskPipeline-->>User: Return SQL output

Poem

🐰 A Rabbit's Ode to Pipeline Prose 🔧

In code's grand dance, a pipeline springs,
With services that softly sing,
Metadata flows, components gleam,
Transforming how our systems dream!

Hop, hop, hooray for smart design! 🚀



coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (11)
wren-ai-service/eval/pipelines.py (9)

15-15: Unused import detected.
The settings import doesn't appear to be used anywhere in this file. Consider removing it to keep the imports clean.

- from src.config import settings


16-16: Unused import detected.
The generate_components import is unused in this file. Consider removing it to reduce clutter.

- from src.providers import generate_components


17-26: Potentially redundant import.
All imported symbols here (AskRequest, AskResultRequest, AskResultResponse, AskService, etc.) are used except possibly AskResultResponse. Since the code calls ask_service_var.get_ask_result(...) (which returns an object of type AskResultResponse), you may keep it for clarity or for type hints. Otherwise, if it's not referenced directly (e.g., in annotations), removing it can reduce confusion.



60-64: Commented-out code identified.
This block is commented out, presumably for debugging or future reference. If no longer needed, consider removing it for clarity.


213-213: Commented-out code identified.
Similar to lines 60-64, if no future usage is planned, consider removing the deploy_model reference for clarity.


355-364: Helper method for dictionary visualization.
The dict_to_string method is a neat utility. Consider adding small checks for cyclical references if there's any chance of them. Otherwise, looks good.


418-434: Polling for ask result.
The loop checks for "finished" or "failed" states; consider adding a max-retry or timeout to avoid an infinite loop in case of unexpected status.

 while (
     ask_result_response.status != "finished"
     and ask_result_response.status != "failed"
 ):
+    # Possibly add a counter or timeout check to prevent infinite loops.
     ask_result_response = self.ask_service_var.get_ask_result(
         ...
     )
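The bounded polling the comment suggests can be sketched as follows; this is a minimal, self-contained illustration (the get_status callback and the retry/backoff constants are hypothetical, not from the PR):

```python
import asyncio


async def poll_until_done(get_status, max_retries=10,
                          initial_backoff=0.01, max_backoff=0.16):
    """Poll get_status() until it reports 'finished' or 'failed',
    with exponential backoff and a retry cap to avoid infinite loops."""
    backoff = initial_backoff
    for _ in range(max_retries):
        status = get_status()
        if status in ("finished", "failed"):
            return status
        await asyncio.sleep(backoff)
        backoff = min(backoff * 2, max_backoff)
    raise TimeoutError("ask result polling exceeded max retries")


# Simulate a service whose status settles on the third check.
statuses = iter(["understanding", "generating", "finished"])
result = asyncio.run(poll_until_done(lambda: next(statuses)))
```

The same shape drops into the pipeline's while loop by replacing the lambda with the get_ask_result call.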

435-435: Commented-out retrieval code.
If retrieval logic is not needed for the AskPipeline, consider removing it to avoid confusion.


445-445: Retrieval context commented out.
Same reasoning as above; remove if it’s no longer part of the pipeline flow.

wren-ai-service/eval/prediction.py (2)

24-27: Partially unused import.
create_service_container is imported but never used. create_service_metadata is utilized at line 189. Consider removing create_service_container.

 from src.globals import (
-    create_service_container,
     create_service_metadata,
 )


91-95: Commented-out repository checks.
If these checks are no longer needed, consider removing them. If they're still useful, add a quick comment explaining why they're disabled.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8933479 and 61de6c8.

📒 Files selected for processing (3)
  • wren-ai-service/eval/pipelines.py (7 hunks)
  • wren-ai-service/eval/prediction.py (5 hunks)
  • wren-ai-service/src/globals.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
wren-ai-service/eval/prediction.py

25-25: src.globals.create_service_container imported but unused

Remove unused import: src.globals.create_service_container

(F401)


145-145: Undefined name provider

(F821)

wren-ai-service/eval/pipelines.py

10-10: orjson imported but unused

Remove unused import: orjson

(F401)


15-15: src.config.settings imported but unused

Remove unused import: src.config.settings

(F401)


16-16: src.providers.generate_components imported but unused

Remove unused import: src.providers.generate_components

(F401)


24-24: src.web.v1.services.ask.AskResultResponse imported but unused

Remove unused import: src.web.v1.services.ask.AskResultResponse

(F401)

🔇 Additional comments (25)
wren-ai-service/eval/pipelines.py (16)

5-5: Import usage is valid.
The uuid import is used to generate a unique project_id at line 383.


11-11: Import usage is valid.
The json import is used to serialize mdl at line 384.


27-32: Imports are used properly.
The pipeline modules (data_assistance, intent_classification, etc.) are referenced within the ask_service method.


57-57: Import usage is valid.
indexing import is used for referencing indexing.DBSchema, indexing.HistoricalQuestion, and indexing.TableDescription.


131-131: No actionable changes.
This line is a comment referencing how the subclass calls __call__.


313-313: Empty line.
No issues here.


315-330: Indexing service definition looks good.
The indexing_service method correctly initializes a SemanticsPreparationService using the component definitions from pipe_components.


331-354: Ask service definition is well-structured.
The ask_service method constructs an AskService instance with properly mapped pipeline steps. This design cleanly separates the pipeline logic.


369-370: Constructor parameters expanded.
Accepting service_metadata and pipe_components broadens the pipeline’s configurability.


373-374: Assignment is straightforward.
Storing service_metadata for later usage. No issues.


382-386: Initialization logic is correct.
Assigning pipeline components, generating IDs, and creating service instances ensures the pipeline is fully prepared.


388-388: Hash generation.
Storing mdl_hash by hashing the JSON string is useful for checks. No issues spotted.


398-407: Preparing semantics step.
The asynchronous call to prepare_semantics uses the correct SemanticsPreparationRequest arguments. Properly references service_metadata.


409-416: Ask request construction.
The AskRequest object is built properly. The logic to generate a query_id is consistent with the pipeline’s flow.


485-486: Function signature extended.
init now includes service_metadata and pipe_components; the new approach is consistent with the constructor changes in AskPipeline.


488-488: Collecting arguments.
Packaging the arguments into a dictionary is a clean pattern for passing them to the pipeline classes.

wren-ai-service/eval/prediction.py (8)

17-18: Valid usage.
Both settings and generate_components are utilized later in this file.


53-56: Settings-based defaults.
Using settings for these indexing and retrieval parameters is consistent with the rest of the codebase.


147-150: Provider outputs assigned properly.
These assignments look fine, assuming the undefined provider issue is resolved.


182-182: Generating pipeline components.
generate_components(settings.components) is a straightforward approach for building the pipeline modules.


187-187: Keyword unpacking usage.
The code unpacks a portion of pipe_components["db_schema_retrieval"] to pass into generate_meta. Ensure that the function indeed accepts those keys.


189-189: Service metadata creation.
The call to create_service_metadata(pipe_components) is consistent with the pipeline’s new constructor design.


194-195: Passing new arguments to pipeline init.
Forwarding service_metadata and pipe_components to match the revised signatures.


198-198: Partial prediction usage.
Only the first item in dataset["eval_dataset"] is predicted. Make sure this behavior is intentional if you expected a full batch run.

wren-ai-service/src/globals.py (1)

70-76: get method in ServiceMetadata.
This new method is straightforward and matches typical dictionary-like access patterns. If get usage grows beyond service_version and pipes_metadata, consider implementing a more robust lookup or typed property approach.

@@ -138,12 +142,12 @@ def init_providers(mdl: dict) -> dict:
     if engine_config is None:
         raise ValueError("Invalid datasource")
 
-    providers = provider.init_providers(engine_config=engine_config)
+    providers_inner = provider.init_providers(engine_config=engine_config)

⚠️ Potential issue

Undefined name provider.
This call to provider.init_providers will fail unless provider is imported or defined. Fix this reference.

- providers_inner = provider.init_providers(engine_config=engine_config)
+ from src.providers import provider
+ providers_inner = provider.init_providers(engine_config=engine_config)

Committable suggestion skipped: line range outside the PR's diff.


@cyyeh changed the title from "Try to make evaluation work again" to "chore(wren-ai-service): Try to make evaluation work again" on Jan 6, 2025
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 5

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 61de6c8 and 3d1b974.

📒 Files selected for processing (3)
  • wren-ai-service/eval/pipelines.py (7 hunks)
  • wren-ai-service/eval/prediction.py (4 hunks)
  • wren-ai-service/src/pipelines/common.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
wren-ai-service/eval/pipelines.py

10-10: orjson imported but unused

Remove unused import: orjson

(F401)


15-15: src.config.settings imported but unused

Remove unused import: src.config.settings

(F401)


16-16: src.providers.generate_components imported but unused

Remove unused import: src.providers.generate_components

(F401)


24-24: src.web.v1.services.ask.AskResultResponse imported but unused

Remove unused import: src.web.v1.services.ask.AskResultResponse

(F401)

wren-ai-service/eval/prediction.py

25-25: src.globals.create_service_container imported but unused

Remove unused import: src.globals.create_service_container

(F401)


145-145: Undefined name provider

(F821)

🔇 Additional comments (1)
wren-ai-service/eval/prediction.py (1)

145-150: ⚠️ Potential issue

Fix undefined reference to provider module.

The code references an undefined provider module. This will cause a NameError at runtime.

Apply this diff to fix the undefined reference:

-    providers_inner = provider.init_providers(engine_config=engine_config)
+    from src.core.provider import init_providers
+    providers_inner = init_providers(engine_config=engine_config)

Likely invalid or redundant comment.


@@ -166,7 +166,7 @@ async def _task(result: Dict[str, str]):
 
             if no_error:
                 status, _, addition = await self._engine.execute_sql(
-                    quoted_sql, session, project_id=project_id
+                    quoted_sql, session, project_id=int(project_id)

💡 Codebase verification

🛠️ Refactor suggestion

Unsafe type casting of project_id needs revision

The type casting of project_id to integer is problematic because:

  • It's inconsistent with the codebase where project_id is handled as an optional string
  • It can raise ValueError for None values, which are valid according to the function signatures

Suggested fix:

quoted_sql, session, project_id=str(project_id) if project_id is not None else None
🔗 Analysis chain

Verify project_id type casting and add error handling.

The change to cast project_id to integer might fix the evaluation error, but it could raise a ValueError if project_id is None or not a valid integer string.

Add error handling to safely handle invalid project_id values:

-                    quoted_sql, session, project_id=int(project_id)
+                    quoted_sql, session, project_id=int(project_id) if project_id is not None else None
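The two suggestions above differ (one keeps project_id as an optional string, the other casts to int with a None guard). A None-safe coercion helper in the spirit of the int-casting diff might look like this sketch (coerce_project_id is a hypothetical name, not part of the codebase):

```python
from typing import Optional, Union


def coerce_project_id(project_id: Optional[Union[str, int]]) -> Optional[int]:
    """Return None unchanged; otherwise cast to int, surfacing bad input clearly."""
    if project_id is None:
        return None
    try:
        return int(project_id)
    except (TypeError, ValueError) as exc:
        raise ValueError(f"invalid project_id: {project_id!r}") from exc
```

The call site would then become execute_sql(quoted_sql, session, project_id=coerce_project_id(project_id)), keeping None values valid while still rejecting non-numeric strings explicitly.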
🏁 Scripts executed

The following scripts were executed for the analysis:

Error: No shell scripts found


Script:

#!/bin/bash
# Find execute_sql function definition and its usage
ast-grep --pattern 'def execute_sql($$$)'

# Search for project_id parameter in SQL execution contexts
rg "execute_sql.*project_id" -A 2

# Look for None checks on project_id
rg "project_id.*None" -A 2
rg "if.*project_id" -A 2

Length of output: 17594

Comment on lines +355 to +363
    def dict_to_string(self, d: dict) -> str:
        if not isinstance(d, dict):
            return str(d)

        result = "{"
        for key, value in d.items():
            result += f"'{key}': {self.dict_to_string(value)}, "
        result = result.rstrip(", ") + "}"
        return result

🛠️ Refactor suggestion

Add error handling for circular references and unhashable keys.

The recursive dictionary to string conversion could fail in several scenarios:

  1. Circular references could cause infinite recursion
  2. Unhashable dictionary keys could raise TypeError

Add error handling to handle these cases:

     def dict_to_string(self, d: dict, seen=None) -> str:
+        if seen is None:
+            seen = set()
+
         if not isinstance(d, dict):
             return str(d)
 
+        # Check for circular references
+        d_id = id(d)
+        if d_id in seen:
+            return "{...}"  # Indicate circular reference
+        seen.add(d_id)
+
         result = "{"
-        for key, value in d.items():
-            result += f"'{key}': {self.dict_to_string(value)}, "
+        try:
+            for key, value in d.items():
+                result += f"'{key}': {self.dict_to_string(value, seen)}, "
+        except TypeError as e:
+            return f"{{Error: {str(e)}}}"
+
         result = result.rstrip(", ") + "}"
+        seen.remove(d_id)
         return result

Comment on lines +331 to +354
    def ask_service(self):

        return AskService(
            {
                "intent_classification": intent_classification.IntentClassification(
                    **self.pipe_components["intent_classification"],
                ),
                "data_assistance": data_assistance.DataAssistance(
                    **self.pipe_components["data_assistance"],
                ),
                "retrieval": retrieval.Retrieval(
                    **self.pipe_components["db_schema_retrieval"],
                ),
                "historical_question": historical_question.HistoricalQuestion(
                    **self.pipe_components["historical_question_retrieval"],
                ),
                "sql_generation": sql_generation.SQLGeneration(
                    **self.pipe_components["sql_generation"],
                ),
                "sql_correction": sql_correction.SQLCorrection(
                    **self.pipe_components["sql_correction"],
                ),
            }
        )

🛠️ Refactor suggestion

Add error handling for missing components.

Similar to indexing_service, this function accesses self.pipe_components dictionary without checking if required components exist.

Add error handling to safely access required components:

     def ask_service(self):
+        required_components = [
+            "intent_classification",
+            "data_assistance",
+            "db_schema_retrieval",
+            "historical_question_retrieval",
+            "sql_generation",
+            "sql_correction"
+        ]
+        missing_components = [comp for comp in required_components if comp not in self.pipe_components]
+        if missing_components:
+            raise ValueError(f"Missing required components: {missing_components}")
+
         return AskService(

Comment on lines +315 to +329
    def indexing_service(self):

        return SemanticsPreparationService(
            {
                "db_schema": indexing.DBSchema(
                    **self.pipe_components["db_schema_indexing"],
                ),
                "historical_question": indexing.HistoricalQuestion(
                    **self.pipe_components["historical_question_indexing"],
                ),
                "table_description": indexing.TableDescription(
                    **self.pipe_components["table_description_indexing"],
                ),
            }
        )

🛠️ Refactor suggestion

Add error handling for missing components.

The function accesses self.pipe_components dictionary without checking if required components exist. This could raise KeyError exceptions if any required component is missing.

Add error handling to safely access required components:

     def indexing_service(self):
+        required_components = [
+            "db_schema_indexing",
+            "historical_question_indexing",
+            "table_description_indexing"
+        ]
+        missing_components = [comp for comp in required_components if comp not in self.pipe_components]
+        if missing_components:
+            raise ValueError(f"Missing required components: {missing_components}")
+
         return SemanticsPreparationService(
             {
                 "db_schema": indexing.DBSchema(

Comment on lines +398 to +445

        await self.indexing_service_var.prepare_semantics(
            SemanticsPreparationRequest(
                mdl=self.mdl_str_var,
                mdl_hash=self.mdl_hash,
                project_id=self.project_id
            ),
            service_metadata=self.service_metadata,
        )

        # asking
        ask_request = AskRequest(
            query=prediction["input"],
            contexts=documents,
            samples=prediction["samples"],
            exclude=[],
            mdl_hash=self.mdl_hash,
            project_id=self.project_id,
        )
        ask_request.query_id = str(uuid.uuid4().int >> 65)
        await self.ask_service_var.ask(ask_request, service_metadata=self.service_metadata)
        # getting ask result
        ask_result_response = self.ask_service_var.get_ask_result(
            AskResultRequest(
                query_id=ask_request.query_id,
            )
        )

        prediction["actual_output"] = actual_output
        prediction["retrieval_context"] = extract_units(documents)
        while (
            ask_result_response.status != "finished"
            and ask_result_response.status != "failed"
        ):
            # getting ask result
            ask_result_response = self.ask_service_var.get_ask_result(
                AskResultRequest(
                    query_id=ask_request.query_id,
                )
            )

        # result = await self._retrieval.run(query=prediction["input"])
        # documents = result.get("construct_retrieval_results", [])
        # actual_output = await self._generation.run(
        #     query=prediction["input"],
        #     contexts=documents,
        #     samples=prediction["samples"],
        #     exclude=[],
        # )

        prediction["actual_output"] = ask_result_response.response[0].sql
        # prediction["retrieval_context"] = extract_units(documents)

🛠️ Refactor suggestion

Add timeout, backoff, and error handling to the polling loop.

The current implementation has several issues:

  1. No timeout mechanism for the polling loop
  2. No backoff strategy between retries
  3. No error handling for service calls

Add these safety mechanisms:

     async def _process(self, prediction: dict, **_) -> dict:
+        MAX_RETRIES = 10
+        INITIAL_BACKOFF = 1  # seconds
+        MAX_BACKOFF = 32  # seconds
+
+        try:
             await self.indexing_service_var.prepare_semantics(
                 SemanticsPreparationRequest(
                     mdl=self.mdl_str_var,
                     mdl_hash=self.mdl_hash,
                     project_id=self.project_id
                 ),
                 service_metadata=self.service_metadata,
             )
 
             ask_request = AskRequest(
                 query=prediction["input"],
                 mdl_hash=self.mdl_hash,
                 project_id=self.project_id,
             )
             ask_request.query_id = str(uuid.uuid4().int >> 65)
             await self.ask_service_var.ask(ask_request, service_metadata=self.service_metadata)
 
             ask_result_response = self.ask_service_var.get_ask_result(
                 AskResultRequest(
                     query_id=ask_request.query_id,
                 )
             )
 
+            retries = 0
+            backoff = INITIAL_BACKOFF
             while (
                 ask_result_response.status != "finished"
                 and ask_result_response.status != "failed"
+                and retries < MAX_RETRIES
             ):
+                await asyncio.sleep(backoff)
+                backoff = min(backoff * 2, MAX_BACKOFF)
+                retries += 1
+
                 ask_result_response = self.ask_service_var.get_ask_result(
                     AskResultRequest(
                         query_id=ask_request.query_id,
                     )
                 )
 
+            if retries >= MAX_RETRIES:
+                raise TimeoutError("Ask service request timed out")
+
+            if ask_result_response.status == "failed":
+                raise RuntimeError(f"Ask service request failed: {ask_result_response.error}")
+
             prediction["actual_output"] = ask_result_response.response[0].sql
+        except Exception as e:
+            logger.exception("Error in _process: %s", str(e))
+            prediction["actual_output"] = None
+            prediction["error"] = str(e)
 
         return prediction
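The capped exponential backoff suggested above can be exercised standalone against a stubbed ask service; `StubAskService` and the shortened delays below are hypothetical test scaffolding, not the real service API:

```python
import asyncio

MAX_RETRIES = 10
INITIAL_BACKOFF = 0.01  # seconds; shortened for the demo
MAX_BACKOFF = 0.05

class StubAskService:
    """Reports 'running' for the first few polls, then 'finished'."""
    def __init__(self, finish_after: int):
        self.calls = 0
        self.finish_after = finish_after

    def get_ask_result(self) -> str:
        self.calls += 1
        return "finished" if self.calls > self.finish_after else "running"

async def poll(service: StubAskService) -> str:
    status = service.get_ask_result()
    retries, backoff = 0, INITIAL_BACKOFF
    while status not in ("finished", "failed") and retries < MAX_RETRIES:
        await asyncio.sleep(backoff)
        backoff = min(backoff * 2, MAX_BACKOFF)  # cap the delay growth
        retries += 1
        status = service.get_ask_result()
    if retries >= MAX_RETRIES and status not in ("finished", "failed"):
        raise TimeoutError("Ask service request timed out")
    return status

status = asyncio.run(poll(StubAskService(finish_after=3)))
```

Doubling the delay with a cap keeps early polls responsive while bounding the load on the service when results are slow.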
@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 1

🧹 Nitpick comments (2)
wren-ai-service/eval/pipelines.py (2)

375-380: Remove or document commented code.

There are multiple lines of commented code without explanation. Either remove them if they're no longer needed or add a comment explaining why they're preserved.

-        # document_store_provider.get_store(recreate_index=True)
-        # _indexing = indexing.Indexing(
-        #     embedder_provider=embedder_provider,
-        #     document_store_provider=document_store_provider,
-        # )
-        # deploy_model(mdl, _indexing)
+        # TODO: Remove after confirming new component-based indexing is stable
+        # Legacy indexing code kept for reference:
+        # document_store_provider.get_store(recreate_index=True)
+        # _indexing = indexing.Indexing(
+        #     embedder_provider=embedder_provider,
+        #     document_store_provider=document_store_provider,
+        # )
+        # deploy_model(mdl, _indexing)

485-488: Add type hints for service_metadata parameter.

The service_metadata parameter lacks type hints, which could lead to type-related issues.

 def init(
     name: Literal["retrieval", "generation", "ask"],
     meta: dict,
     mdl: dict,
-    service_metadata,
+    service_metadata: Dict[str, Any],
     pipe_components: Dict[str, Any],
 ) -> Eval:
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3d1b974 and 2c7ab2c.

📒 Files selected for processing (2)
  • wren-ai-service/eval/pipelines.py (7 hunks)
  • wren-ai-service/eval/prediction.py (4 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
wren-ai-service/eval/prediction.py

25-25: src.globals.create_service_container imported but unused

Remove unused import: src.globals.create_service_container

(F401)


145-145: Undefined name provider

(F821)

wren-ai-service/eval/pipelines.py

10-10: orjson imported but unused

Remove unused import: orjson

(F401)


15-15: src.config.settings imported but unused

Remove unused import: src.config.settings

(F401)


16-16: src.providers.generate_components imported but unused

Remove unused import: src.providers.generate_components

(F401)


24-24: src.web.v1.services.ask.AskResultResponse imported but unused

Remove unused import: src.web.v1.services.ask.AskResultResponse

(F401)

🔇 Additional comments (5)
wren-ai-service/eval/prediction.py (2)

53-56: LGTM! Configuration management improvement.

Good refactor to use centralized settings with fallback values.


17-27: ⚠️ Potential issue

Fix missing provider import causing runtime error.

The code attempts to use provider.init_providers on line 145, but the provider module is no longer imported.

Add the following import:

+from src.providers import provider

Likely invalid or redundant comment.

🧰 Tools
🪛 Ruff (0.8.2)

25-25: src.globals.create_service_container imported but unused

Remove unused import: src.globals.create_service_container

(F401)

wren-ai-service/eval/pipelines.py (3)

355-363: 🛠️ Refactor suggestion

Add protection against circular references and unhashable keys.

The dictionary to string conversion is vulnerable to:

  1. Infinite recursion from circular references
  2. TypeError from unhashable dictionary keys

Add these safety mechanisms:

-    def dict_to_string(self, d: dict) -> str:
+    def dict_to_string(self, d: dict, seen=None) -> str:
+        if seen is None:
+            seen = set()
+            
         if not isinstance(d, dict):
             return str(d)
             
+        # Check for circular references
+        d_id = id(d)
+        if d_id in seen:
+            return "{...}"  # Indicate circular reference
+        seen.add(d_id)
         
         result = "{"
-        for key, value in d.items():
-            result += f"'{key}': {self.dict_to_string(value)}, "
+        try:
+            for key, value in d.items():
+                result += f"'{key}': {self.dict_to_string(value, seen)}, "
+        except TypeError as e:
+            return f"{{Error: {str(e)}}}"
         
         result = result.rstrip(", ") + "}"
+        seen.remove(d_id)
         return result
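The cycle guard in the suggestion above can be checked quickly in a standalone, free-function form (same logic as the method, lifted out of the class for illustration):

```python
def dict_to_string(d, seen=None):
    """Render a dict as a string, guarding against circular references."""
    if seen is None:
        seen = set()
    if not isinstance(d, dict):
        return str(d)
    d_id = id(d)
    if d_id in seen:
        return "{...}"  # circular reference marker
    seen.add(d_id)
    result = "{"
    for key, value in d.items():
        result += f"'{key}': {dict_to_string(value, seen)}, "
    result = result.rstrip(", ") + "}"
    seen.remove(d_id)  # allow the same dict to appear in sibling branches
    return result

plain = dict_to_string({"a": 1, "b": {"c": 2}})

cyclic = {}
cyclic["self"] = cyclic  # self-referencing dict would recurse forever unguarded
guarded = dict_to_string(cyclic)
```

Tracking visited ids only along the current recursion path (adding before descent, removing after) distinguishes true cycles from a dict legitimately shared between two branches.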

Likely invalid or redundant comment.


364-388: 🛠️ Refactor suggestion

Add validation for required components.

The constructor should validate the presence of all required components before initialization to prevent runtime errors.

Add component validation:

     def __init__(
         self,
         meta: dict,
         mdl: dict,
         service_metadata,
         pipe_components,
     ):
         super().__init__(meta, 3)
+        required_components = {
+            'db_schema_indexing',
+            'historical_question_indexing',
+            'table_description_indexing',
+            'intent_classification',
+            'data_assistance',
+            'db_schema_retrieval',
+            'historical_question_retrieval',
+            'sql_generation',
+            'sql_correction'
+        }
+        
+        missing = required_components - set(pipe_components.keys())
+        if missing:
+            raise ValueError(f"Missing required components: {missing}")
+            
         self.service_metadata = service_metadata
         self.pipe_components = pipe_components
         self.project_id = str(uuid.uuid4().int >> 65)

Likely invalid or redundant comment.


398-445: ⚠️ Potential issue

Add timeout and error handling to the polling loop.

The current polling implementation lacks several critical safety mechanisms:

  1. No timeout to prevent infinite loops
  2. No delay between retries
  3. No error handling for failed status

Add these safety mechanisms:

     async def _process(self, prediction: dict, **_) -> dict:
+        MAX_RETRIES = 10
+        RETRY_DELAY = 1  # seconds
+        
+        try:
             await self.indexing_service_var.prepare_semantics(
                 SemanticsPreparationRequest(
                     mdl=self.mdl_str_var,
                     mdl_hash=self.mdl_hash,
                     project_id=self.project_id
                 ),
                 service_metadata=self.service_metadata,
             )

             ask_request = AskRequest(
                 query=prediction["input"],
                 mdl_hash=self.mdl_hash,
                 project_id=self.project_id,
             )
             ask_request.query_id = str(uuid.uuid4().int >> 65)
             await self.ask_service_var.ask(ask_request, service_metadata=self.service_metadata)

             ask_result_response = self.ask_service_var.get_ask_result(
                 AskResultRequest(
                     query_id=ask_request.query_id,
                 )
             )

+            retries = 0
             while (
                 ask_result_response.status != "finished"
                 and ask_result_response.status != "failed"
+                and retries < MAX_RETRIES
             ):
+                await asyncio.sleep(RETRY_DELAY)
+                retries += 1
                 ask_result_response = self.ask_service_var.get_ask_result(
                     AskResultRequest(
                         query_id=ask_request.query_id,
                     )
                 )

+            if retries >= MAX_RETRIES:
+                raise TimeoutError(f"Ask service request timed out after {MAX_RETRIES} retries")
+            
+            if ask_result_response.status == "failed":
+                raise RuntimeError(f"Ask service request failed: {ask_result_response.error}")
+            
+            if not ask_result_response.response:
+                raise ValueError("Empty response from ask service")
+                
             prediction["actual_output"] = ask_result_response.response[0].sql
+        except Exception as e:
+            prediction["error"] = str(e)
+            prediction["actual_output"] = None

         return prediction

Likely invalid or redundant comment.

Comment on lines +182 to +195
pipe_components = generate_components(settings.components)
meta = generate_meta(
path=path,
dataset=dataset,
pipe=pipe_name,
**providers,
**pipe_components["db_schema_retrieval"],
)

service_metadata = create_service_metadata(pipe_components)
pipe = pipelines.init(
pipe_name,
meta,
mdl=dataset["mdl"],
providers=providers,
service_metadata=service_metadata,
pipe_components=pipe_components,
🛠️ Refactor suggestion

Add error handling for component generation.

The component generation and service metadata creation could fail silently. Consider adding error handling:

-    pipe_components = generate_components(settings.components)
-    meta = generate_meta(
-        path=path,
-        dataset=dataset,
-        pipe=pipe_name,
-        **pipe_components["db_schema_retrieval"],
-    )
-    service_metadata = create_service_metadata(pipe_components)
+    try:
+        pipe_components = generate_components(settings.components)
+        if not pipe_components.get("db_schema_retrieval"):
+            raise ValueError("Required component 'db_schema_retrieval' not found")
+        
+        meta = generate_meta(
+            path=path,
+            dataset=dataset,
+            pipe=pipe_name,
+            **pipe_components["db_schema_retrieval"],
+        )
+        service_metadata = create_service_metadata(pipe_components)
+    except Exception as e:
+        raise RuntimeError(f"Failed to initialize components: {str(e)}") from e
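The `raise ... from e` pattern in the suggestion preserves the original exception as `__cause__`, so the root cause survives in the traceback; a minimal illustration with a hypothetical stand-in for `generate_components`:

```python
def generate_components(config):
    # Hypothetical stand-in that fails the way a misconfigured provider might.
    raise KeyError("llm_provider")

def init_components(config):
    try:
        return generate_components(config)
    except Exception as e:
        # Chain with `from e` so the underlying error stays attached.
        raise RuntimeError(f"Failed to initialize components: {e}") from e

try:
    init_components({})
except RuntimeError as err:
    cause = err.__cause__  # the original KeyError, not lost in the wrapping
```

Without the `from e`, Python would still record the context, but explicit chaining signals the wrap is deliberate and keeps logs pointing at the real misconfiguration.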
@paopa paopa requested review from cyyeh and paopa January 14, 2025 08:24
@paopa paopa removed request for cyyeh and paopa January 21, 2025 08:31