-
Notifications
You must be signed in to change notification settings - Fork 430
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
chore(wren-ai-service): Try to make evaluation work again #1085
base: main
Are you sure you want to change the base?
Conversation
Warning Rate limit exceeded@tedyyan has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 5 minutes and 49 seconds before requesting another review. ⌛ How to resolve this issue?After the wait time has elapsed, a review can be triggered using the We recommend that you space out your commits to avoid hitting the rate limit. 🚦 How do rate limits work?CodeRabbit enforces hourly rate limits for each developer per organization. Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout. Please see our FAQ for further information. 📒 Files selected for processing (3)
WalkthroughThe pull request introduces substantial modifications to the AI service's evaluation and prediction pipelines. Key changes include the addition of the Changes
Sequence DiagramsequenceDiagram
participant User
participant AskPipeline
participant IndexingService
participant AskService
User->>AskPipeline: Initialize with metadata
AskPipeline->>IndexingService: Prepare semantics
IndexingService-->>AskPipeline: Semantics ready
AskPipeline->>AskService: Create AskRequest
AskService-->>AskPipeline: Process query
AskPipeline-->>User: Return SQL output
Poem
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (11)
wren-ai-service/eval/pipelines.py (9)
15-15
: Unused import detected.
Thesettings
import doesn't appear to be used anywhere in this file. Consider removing it to keep the imports clean.- from src.config import settings
🧰 Tools
🪛 Ruff (0.8.2)
15-15:
src.config.settings
imported but unusedRemove unused import:
src.config.settings
(F401)
16-16
: Unused import detected.
Thegenerate_components
import is unused in this file. Consider removing it to reduce clutter.- from src.providers import generate_components
🧰 Tools
🪛 Ruff (0.8.2)
16-16:
src.providers.generate_components
imported but unusedRemove unused import:
src.providers.generate_components
(F401)
17-26
: Potentially redundant import.
All imported symbols here (AskRequest
,AskResultRequest
,AskResultResponse
,AskService
, etc.) are used except possiblyAskResultResponse
. Since the code callsask_service_var.get_ask_result(...)
(which returns an object of typeAskResultResponse
), you may keep it for clarity or for type hints. Otherwise, if it's not referenced directly (e.g., in annotations), removing it can reduce confusion.🧰 Tools
🪛 Ruff (0.8.2)
24-24:
src.web.v1.services.ask.AskResultResponse
imported but unusedRemove unused import:
src.web.v1.services.ask.AskResultResponse
(F401)
60-64
: Commented-out code identified.
This block is commented out, presumably for debugging or future reference. If no longer needed, consider removing it for clarity.
213-213
: Commented-out code identified.
Similar to lines 60-64, if no future usage is planned, consider removing thedeploy_model
reference for clarity.
355-364
: Helper method for dictionary visualization.
Thedict_to_string
method is a neat utility. Consider adding small checks for cyclical references if there's any chance of them. Otherwise, looks good.
418-434
: Polling for ask result.
The loop checks for"finished"
or"failed"
states; consider adding a max-retry or timeout to avoid an infinite loop in case of unexpected status.while ( ask_result_response.status != "finished" and ask_result_response.status != "failed" ): + # Possibly add a counter or timeout check to prevent infinite loops. ask_result_response = self.ask_service_var.get_ask_result( ... )
435-435
: Commented-out retrieval code.
If retrieval logic is not needed for theAskPipeline
, consider removing it to avoid confusion.
445-445
: Retrieval context commented out.
Same reasoning as above; remove if it’s no longer part of the pipeline flow.wren-ai-service/eval/prediction.py (2)
24-27
: Partially unused import.
create_service_container
is imported but never used.create_service_metadata
is utilized at line 189. Consider removingcreate_service_container
.from src.globals import ( - create_service_container, create_service_metadata, )
🧰 Tools
🪛 Ruff (0.8.2)
25-25:
src.globals.create_service_container
imported but unusedRemove unused import:
src.globals.create_service_container
(F401)
91-95
: Commented-out repository checks.
If these checks are no longer needed, consider removing them. If they're still useful, add a quick comment explaining why they're disabled.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
wren-ai-service/eval/pipelines.py
(7 hunks)wren-ai-service/eval/prediction.py
(5 hunks)wren-ai-service/src/globals.py
(1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
wren-ai-service/eval/prediction.py
25-25: src.globals.create_service_container
imported but unused
Remove unused import: src.globals.create_service_container
(F401)
145-145: Undefined name provider
(F821)
wren-ai-service/eval/pipelines.py
10-10: orjson
imported but unused
Remove unused import: orjson
(F401)
15-15: src.config.settings
imported but unused
Remove unused import: src.config.settings
(F401)
16-16: src.providers.generate_components
imported but unused
Remove unused import: src.providers.generate_components
(F401)
24-24: src.web.v1.services.ask.AskResultResponse
imported but unused
Remove unused import: src.web.v1.services.ask.AskResultResponse
(F401)
🔇 Additional comments (25)
wren-ai-service/eval/pipelines.py (16)
5-5
: Import usage is valid.
Theuuid
import is used to generate a uniqueproject_id
at line 383.
11-11
: Import usage is valid.
Thejson
import is used to serializemdl
at line 384.
27-32
: Imports are used properly.
The pipeline modules (data_assistance
,intent_classification
, etc.) are referenced within theask_service
method.
57-57
: Import usage is valid.
indexing
import is used for referencingindexing.DBSchema
,indexing.HistoricalQuestion
, andindexing.TableDescription
.
131-131
: No actionable changes.
This line is a comment referencing how the subclass calls__call__
.
313-313
: Empty line.
No issues here.
315-330
: Indexing service definition looks good.
Theindexing_service
method correctly initializes aSemanticsPreparationService
using the component definitions frompipe_components
.
331-354
: Ask service definition is well-structured.
Theask_service
method constructs anAskService
instance with properly mapped pipeline steps. This design cleanly separates the pipeline logic.
369-370
: Constructor parameters expanded.
Acceptingservice_metadata
andpipe_components
broadens the pipeline’s configurability.
373-374
: Assignment is straightforward.
Storingservice_metadata
for later usage. No issues.
382-386
: Initialization logic is correct.
Assigning pipeline components, generating IDs, and creating service instances ensures the pipeline is fully prepared.
388-388
: Hash generation.
Storingmdl_hash
by hashing the JSON string is useful for checks. No issues spotted.
398-407
: Preparing semantics step.
The asynchronous call toprepare_semantics
uses the correctSemanticsPreparationRequest
arguments. Properly referencesservice_metadata
.
409-416
: Ask request construction.
TheAskRequest
object is built properly. The logic to generate aquery_id
is consistent with the pipeline’s flow.
485-486
: Function signature extended.
init
now includesservice_metadata
andpipe_components
; the new approach is consistent with the constructor changes inAskPipeline
.
488-488
: Collecting arguments.
Packaging the arguments into a dictionary is a clean pattern for passing them to the pipeline classes.wren-ai-service/eval/prediction.py (8)
17-18
: Valid usage.
Bothsettings
andgenerate_components
are utilized later in this file.
53-56
: Settings-based defaults.
Usingsettings
for these indexing and retrieval parameters is consistent with the rest of the codebase.
147-150
: Provider outputs assigned properly.
These assignments look fine, assuming the undefinedprovider
issue is resolved.
182-182
: Generating pipeline components.
generate_components(settings.components)
is a straightforward approach for building the pipeline modules.
187-187
: Keyword unpacking usage.
The code unpacks a portion ofpipe_components["db_schema_retrieval"]
to pass intogenerate_meta
. Ensure that the function indeed accepts those keys.
189-189
: Service metadata creation.
The call tocreate_service_metadata(pipe_components)
is consistent with the pipeline’s new constructor design.
194-195
: Passing new arguments to pipeline init.
Forwardingservice_metadata
andpipe_components
to match the revised signatures.
198-198
: Partial prediction usage.
Only the first item indataset["eval_dataset"]
is predicted. Make sure this behavior is intentional if you expected a full batch run.wren-ai-service/src/globals.py (1)
70-76
:get
method inServiceMetadata
.
This new method is straightforward and matches typical dictionary-like access patterns. Ifget
usage grows beyondservice_version
andpipes_metadata
, consider implementing a more robust lookup or typed property approach.
@@ -138,12 +142,12 @@ def init_providers(mdl: dict) -> dict: | |||
if engine_config is None: | |||
raise ValueError("Invalid datasource") | |||
|
|||
providers = provider.init_providers(engine_config=engine_config) | |||
providers_inner = provider.init_providers(engine_config=engine_config) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Undefined name provider
.
This call to provider.init_providers
will fail unless provider
is imported or defined. Fix this reference.
- providers_inner = provider.init_providers(engine_config=engine_config)
+ from src.providers import provider
+ providers_inner = provider.init_providers(engine_config=engine_config)
Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff (0.8.2)
145-145: Undefined name provider
(F821)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 5
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
wren-ai-service/eval/pipelines.py
(7 hunks)wren-ai-service/eval/prediction.py
(4 hunks)wren-ai-service/src/pipelines/common.py
(1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
wren-ai-service/eval/pipelines.py
10-10: orjson
imported but unused
Remove unused import: orjson
(F401)
15-15: src.config.settings
imported but unused
Remove unused import: src.config.settings
(F401)
16-16: src.providers.generate_components
imported but unused
Remove unused import: src.providers.generate_components
(F401)
24-24: src.web.v1.services.ask.AskResultResponse
imported but unused
Remove unused import: src.web.v1.services.ask.AskResultResponse
(F401)
wren-ai-service/eval/prediction.py
25-25: src.globals.create_service_container
imported but unused
Remove unused import: src.globals.create_service_container
(F401)
145-145: Undefined name provider
(F821)
🔇 Additional comments (1)
wren-ai-service/eval/prediction.py (1)
145-150
:⚠️ Potential issueFix undefined reference to
provider
module.The code references an undefined
provider
module. This will cause a NameError at runtime.Apply this diff to fix the undefined reference:
- providers_inner = provider.init_providers(engine_config=engine_config) + from src.core.provider import init_providers + providers_inner = init_providers(engine_config=engine_config)Likely invalid or redundant comment.
🧰 Tools
🪛 Ruff (0.8.2)
145-145: Undefined name
provider
(F821)
@@ -166,7 +166,7 @@ async def _task(result: Dict[str, str]): | |||
|
|||
if no_error: | |||
status, _, addition = await self._engine.execute_sql( | |||
quoted_sql, session, project_id=project_id | |||
quoted_sql, session, project_id=int(project_id) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codebase verification
🛠️ Refactor suggestion
Unsafe type casting of project_id needs revision
The type casting of project_id
to integer is problematic because:
- It's inconsistent with the codebase where
project_id
is handled as an optional string - It can raise ValueError for None values, which are valid according to the function signatures
Suggested fix:
quoted_sql, session, project_id=str(project_id) if project_id is not None else None
🔗 Analysis chain
Verify project_id type casting and add error handling.
The change to cast project_id
to integer might fix the evaluation error, but it could raise a ValueError if project_id
is None or not a valid integer string.
Add error handling to safely handle invalid project_id values:
- quoted_sql, session, project_id=int(project_id)
+ quoted_sql, session, project_id=int(project_id) if project_id is not None else None
🏁 Scripts executed
The following scripts were executed for the analysis:
Error: No shell scripts found
Script:
#!/bin/bash
# Find execute_sql function definition and its usage
ast-grep --pattern 'def execute_sql($$$)'
# Search for project_id parameter in SQL execution contexts
rg "execute_sql.*project_id" -A 2
# Look for None checks on project_id
rg "project_id.*None" -A 2
rg "if.*project_id" -A 2
Length of output: 17594
def dict_to_string(self, d: dict) -> str: | ||
if not isinstance(d, dict): | ||
return str(d) | ||
|
||
result = "{" | ||
for key, value in d.items(): | ||
result += f"'{key}': {self.dict_to_string(value)}, " | ||
result = result.rstrip(", ") + "}" | ||
return result |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Add error handling for circular references and unhashable keys.
The recursive dictionary to string conversion could fail in several scenarios:
- Circular references could cause infinite recursion
- Unhashable dictionary keys could raise TypeError
Add error handling to handle these cases:
def dict_to_string(self, d: dict, seen=None) -> str:
+ if seen is None:
+ seen = set()
+
if not isinstance(d, dict):
return str(d)
+ # Check for circular references
+ d_id = id(d)
+ if d_id in seen:
+ return "{...}" # Indicate circular reference
+ seen.add(d_id)
+
result = "{"
- for key, value in d.items():
- result += f"'{key}': {self.dict_to_string(value)}, "
+ try:
+ for key, value in d.items():
+ result += f"'{key}': {self.dict_to_string(value, seen)}, "
+ except TypeError as e:
+ return f"{{Error: {str(e)}}}"
+
result = result.rstrip(", ") + "}"
+ seen.remove(d_id)
return result
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def dict_to_string(self, d: dict) -> str: | |
if not isinstance(d, dict): | |
return str(d) | |
result = "{" | |
for key, value in d.items(): | |
result += f"'{key}': {self.dict_to_string(value)}, " | |
result = result.rstrip(", ") + "}" | |
return result | |
def dict_to_string(self, d: dict, seen=None) -> str: | |
if seen is None: | |
seen = set() | |
if not isinstance(d, dict): | |
return str(d) | |
# Check for circular references | |
d_id = id(d) | |
if d_id in seen: | |
return "{...}" # Indicate circular reference | |
seen.add(d_id) | |
result = "{" | |
try: | |
for key, value in d.items(): | |
result += f"'{key}': {self.dict_to_string(value, seen)}, " | |
except TypeError as e: | |
return f"{{Error: {str(e)}}}" | |
result = result.rstrip(", ") + "}" | |
seen.remove(d_id) | |
return result |
def ask_service(self): | ||
|
||
return AskService( | ||
{ | ||
"intent_classification": intent_classification.IntentClassification( | ||
**self.pipe_components["intent_classification"], | ||
), | ||
"data_assistance": data_assistance.DataAssistance( | ||
**self.pipe_components["data_assistance"], | ||
), | ||
"retrieval": retrieval.Retrieval( | ||
**self.pipe_components["db_schema_retrieval"], | ||
), | ||
"historical_question": historical_question.HistoricalQuestion( | ||
**self.pipe_components["historical_question_retrieval"], | ||
), | ||
"sql_generation": sql_generation.SQLGeneration( | ||
**self.pipe_components["sql_generation"], | ||
), | ||
"sql_correction": sql_correction.SQLCorrection( | ||
**self.pipe_components["sql_correction"], | ||
), | ||
} | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Add error handling for missing components.
Similar to indexing_service
, this function accesses self.pipe_components
dictionary without checking if required components exist.
Add error handling to safely access required components:
def ask_service(self):
+ required_components = [
+ "intent_classification",
+ "data_assistance",
+ "db_schema_retrieval",
+ "historical_question_retrieval",
+ "sql_generation",
+ "sql_correction"
+ ]
+ missing_components = [comp for comp in required_components if comp not in self.pipe_components]
+ if missing_components:
+ raise ValueError(f"Missing required components: {missing_components}")
+
return AskService(
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def ask_service(self): | |
return AskService( | |
{ | |
"intent_classification": intent_classification.IntentClassification( | |
**self.pipe_components["intent_classification"], | |
), | |
"data_assistance": data_assistance.DataAssistance( | |
**self.pipe_components["data_assistance"], | |
), | |
"retrieval": retrieval.Retrieval( | |
**self.pipe_components["db_schema_retrieval"], | |
), | |
"historical_question": historical_question.HistoricalQuestion( | |
**self.pipe_components["historical_question_retrieval"], | |
), | |
"sql_generation": sql_generation.SQLGeneration( | |
**self.pipe_components["sql_generation"], | |
), | |
"sql_correction": sql_correction.SQLCorrection( | |
**self.pipe_components["sql_correction"], | |
), | |
} | |
) | |
def ask_service(self): | |
required_components = [ | |
"intent_classification", | |
"data_assistance", | |
"db_schema_retrieval", | |
"historical_question_retrieval", | |
"sql_generation", | |
"sql_correction" | |
] | |
missing_components = [comp for comp in required_components if comp not in self.pipe_components] | |
if missing_components: | |
raise ValueError(f"Missing required components: {missing_components}") | |
return AskService( | |
{ | |
"intent_classification": intent_classification.IntentClassification( | |
**self.pipe_components["intent_classification"], | |
), | |
"data_assistance": data_assistance.DataAssistance( | |
**self.pipe_components["data_assistance"], | |
), | |
"retrieval": retrieval.Retrieval( | |
**self.pipe_components["db_schema_retrieval"], | |
), | |
"historical_question": historical_question.HistoricalQuestion( | |
**self.pipe_components["historical_question_retrieval"], | |
), | |
"sql_generation": sql_generation.SQLGeneration( | |
**self.pipe_components["sql_generation"], | |
), | |
"sql_correction": sql_correction.SQLCorrection( | |
**self.pipe_components["sql_correction"], | |
), | |
} | |
) |
def indexing_service(self): | ||
|
||
return SemanticsPreparationService( | ||
{ | ||
"db_schema": indexing.DBSchema( | ||
**self.pipe_components["db_schema_indexing"], | ||
), | ||
"historical_question": indexing.HistoricalQuestion( | ||
**self.pipe_components["historical_question_indexing"], | ||
), | ||
"table_description": indexing.TableDescription( | ||
**self.pipe_components["table_description_indexing"], | ||
), | ||
} | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Add error handling for missing components.
The function accesses self.pipe_components
dictionary without checking if required components exist. This could raise KeyError exceptions if any required component is missing.
Add error handling to safely access required components:
def indexing_service(self):
+ required_components = [
+ "db_schema_indexing",
+ "historical_question_indexing",
+ "table_description_indexing"
+ ]
+ missing_components = [comp for comp in required_components if comp not in self.pipe_components]
+ if missing_components:
+ raise ValueError(f"Missing required components: {missing_components}")
+
return SemanticsPreparationService(
{
"db_schema": indexing.DBSchema(
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def indexing_service(self): | |
return SemanticsPreparationService( | |
{ | |
"db_schema": indexing.DBSchema( | |
**self.pipe_components["db_schema_indexing"], | |
), | |
"historical_question": indexing.HistoricalQuestion( | |
**self.pipe_components["historical_question_indexing"], | |
), | |
"table_description": indexing.TableDescription( | |
**self.pipe_components["table_description_indexing"], | |
), | |
} | |
) | |
def indexing_service(self): | |
required_components = [ | |
"db_schema_indexing", | |
"historical_question_indexing", | |
"table_description_indexing" | |
] | |
missing_components = [comp for comp in required_components if comp not in self.pipe_components] | |
if missing_components: | |
raise ValueError(f"Missing required components: {missing_components}") | |
return SemanticsPreparationService( | |
{ | |
"db_schema": indexing.DBSchema( | |
**self.pipe_components["db_schema_indexing"], | |
), | |
"historical_question": indexing.HistoricalQuestion( | |
**self.pipe_components["historical_question_indexing"], | |
), | |
"table_description": indexing.TableDescription( | |
**self.pipe_components["table_description_indexing"], | |
), | |
} | |
) |
|
||
await self.indexing_service_var.prepare_semantics( | ||
SemanticsPreparationRequest( | ||
mdl=self.mdl_str_var, | ||
mdl_hash=self.mdl_hash, | ||
project_id=self.project_id | ||
), | ||
service_metadata=self.service_metadata, | ||
) | ||
|
||
# asking | ||
ask_request = AskRequest( | ||
query=prediction["input"], | ||
contexts=documents, | ||
samples=prediction["samples"], | ||
exclude=[], | ||
mdl_hash=self.mdl_hash, | ||
project_id = self.project_id, | ||
|
||
) | ||
ask_request.query_id = str(uuid.uuid4().int >> 65) | ||
await self.ask_service_var.ask(ask_request, service_metadata=self.service_metadata) | ||
# getting ask result | ||
ask_result_response = self.ask_service_var.get_ask_result( | ||
AskResultRequest( | ||
query_id=ask_request.query_id, | ||
) | ||
) | ||
|
||
prediction["actual_output"] = actual_output | ||
prediction["retrieval_context"] = extract_units(documents) | ||
while ( | ||
ask_result_response.status != "finished" | ||
and ask_result_response.status != "failed" | ||
): | ||
# getting ask result | ||
ask_result_response = self.ask_service_var.get_ask_result( | ||
AskResultRequest( | ||
query_id=ask_request.query_id, | ||
) | ||
) | ||
|
||
# result = await self._retrieval.run(query=prediction["input"]) | ||
# documents = result.get("construct_retrieval_results", []) | ||
# actual_output = await self._generation.run( | ||
# query=prediction["input"], | ||
# contexts=documents, | ||
# samples=prediction["samples"], | ||
# exclude=[], | ||
# ) | ||
|
||
prediction["actual_output"] = ask_result_response.response[0].sql | ||
#prediction["retrieval_context"] = extract_units(documents) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Add timeout, backoff, and error handling to the polling loop.
The current implementation has several issues:
- No timeout mechanism for the polling loop
- No backoff strategy between retries
- No error handling for service calls
Add these safety mechanisms:
async def _process(self, prediction: dict, **_) -> dict:
+ MAX_RETRIES = 10
+ INITIAL_BACKOFF = 1 # seconds
+ MAX_BACKOFF = 32 # seconds
+
+ try:
await self.indexing_service_var.prepare_semantics(
SemanticsPreparationRequest(
mdl=self.mdl_str_var,
mdl_hash=self.mdl_hash,
project_id=self.project_id
),
service_metadata=self.service_metadata,
)
ask_request = AskRequest(
query=prediction["input"],
mdl_hash=self.mdl_hash,
project_id=self.project_id,
)
ask_request.query_id = str(uuid.uuid4().int >> 65)
await self.ask_service_var.ask(ask_request, service_metadata=self.service_metadata)
ask_result_response = self.ask_service_var.get_ask_result(
AskResultRequest(
query_id=ask_request.query_id,
)
)
+ retries = 0
+ backoff = INITIAL_BACKOFF
while (
ask_result_response.status != "finished"
and ask_result_response.status != "failed"
+ and retries < MAX_RETRIES
):
+ await asyncio.sleep(backoff)
+ backoff = min(backoff * 2, MAX_BACKOFF)
+ retries += 1
+
ask_result_response = self.ask_service_var.get_ask_result(
AskResultRequest(
query_id=ask_request.query_id,
)
)
+ if retries >= MAX_RETRIES:
+ raise TimeoutError("Ask service request timed out")
+
+ if ask_result_response.status == "failed":
+ raise RuntimeError(f"Ask service request failed: {ask_result_response.error}")
+
prediction["actual_output"] = ask_result_response.response[0].sql
+ except Exception as e:
+ logger.exception("Error in _process: %s", str(e))
+ prediction["actual_output"] = None
+ prediction["error"] = str(e)
return prediction
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
await self.indexing_service_var.prepare_semantics( | |
SemanticsPreparationRequest( | |
mdl=self.mdl_str_var, | |
mdl_hash=self.mdl_hash, | |
project_id=self.project_id | |
), | |
service_metadata=self.service_metadata, | |
) | |
# asking | |
ask_request = AskRequest( | |
query=prediction["input"], | |
contexts=documents, | |
samples=prediction["samples"], | |
exclude=[], | |
mdl_hash=self.mdl_hash, | |
project_id = self.project_id, | |
) | |
ask_request.query_id = str(uuid.uuid4().int >> 65) | |
await self.ask_service_var.ask(ask_request, service_metadata=self.service_metadata) | |
# getting ask result | |
ask_result_response = self.ask_service_var.get_ask_result( | |
AskResultRequest( | |
query_id=ask_request.query_id, | |
) | |
) | |
prediction["actual_output"] = actual_output | |
prediction["retrieval_context"] = extract_units(documents) | |
while ( | |
ask_result_response.status != "finished" | |
and ask_result_response.status != "failed" | |
): | |
# getting ask result | |
ask_result_response = self.ask_service_var.get_ask_result( | |
AskResultRequest( | |
query_id=ask_request.query_id, | |
) | |
) | |
# result = await self._retrieval.run(query=prediction["input"]) | |
# documents = result.get("construct_retrieval_results", []) | |
# actual_output = await self._generation.run( | |
# query=prediction["input"], | |
# contexts=documents, | |
# samples=prediction["samples"], | |
# exclude=[], | |
# ) | |
prediction["actual_output"] = ask_result_response.response[0].sql | |
#prediction["retrieval_context"] = extract_units(documents) | |
async def _process(self, prediction: dict, **_) -> dict: | |
MAX_RETRIES = 10 | |
INITIAL_BACKOFF = 1 # seconds | |
MAX_BACKOFF = 32 # seconds | |
try: | |
await self.indexing_service_var.prepare_semantics( | |
SemanticsPreparationRequest( | |
mdl=self.mdl_str_var, | |
mdl_hash=self.mdl_hash, | |
project_id=self.project_id | |
), | |
service_metadata=self.service_metadata, | |
) | |
ask_request = AskRequest( | |
query=prediction["input"], | |
mdl_hash=self.mdl_hash, | |
project_id=self.project_id, | |
) | |
ask_request.query_id = str(uuid.uuid4().int >> 65) | |
await self.ask_service_var.ask(ask_request, service_metadata=self.service_metadata) | |
ask_result_response = self.ask_service_var.get_ask_result( | |
AskResultRequest( | |
query_id=ask_request.query_id, | |
) | |
) | |
retries = 0 | |
backoff = INITIAL_BACKOFF | |
while ( | |
ask_result_response.status != "finished" | |
and ask_result_response.status != "failed" | |
and retries < MAX_RETRIES | |
): | |
await asyncio.sleep(backoff) | |
backoff = min(backoff * 2, MAX_BACKOFF) | |
retries += 1 | |
ask_result_response = self.ask_service_var.get_ask_result( | |
AskResultRequest( | |
query_id=ask_request.query_id, | |
) | |
) | |
if retries >= MAX_RETRIES: | |
raise TimeoutError("Ask service request timed out") | |
if ask_result_response.status == "failed": | |
raise RuntimeError(f"Ask service request failed: {ask_result_response.error}") | |
prediction["actual_output"] = ask_result_response.response[0].sql | |
except Exception as e: | |
logger.exception("Error in _process: %s", str(e)) | |
prediction["actual_output"] = None | |
prediction["error"] = str(e) | |
return prediction |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (2)
wren-ai-service/eval/pipelines.py (2)
375-380
: Remove or document commented code.There are multiple lines of commented code without explanation. Either remove them if they're no longer needed or add a comment explaining why they're preserved.
- # document_store_provider.get_store(recreate_index=True) - # _indexing = indexing.Indexing( - # embedder_provider=embedder_provider, - # document_store_provider=document_store_provider, - # ) - # deploy_model(mdl, _indexing) + # TODO: Remove after confirming new component-based indexing is stable + # Legacy indexing code kept for reference: + # document_store_provider.get_store(recreate_index=True) + # _indexing = indexing.Indexing( + # embedder_provider=embedder_provider, + # document_store_provider=document_store_provider, + # ) + # deploy_model(mdl, _indexing)
485-488
: Add type hints for service_metadata parameter.The
service_metadata
parameter lacks type hints, which could lead to type-related issues.def init( name: Literal["retrieval", "generation", "ask"], meta: dict, mdl: dict, - service_metadata, + service_metadata: Dict[str, Any], pipe_components: Dict[str, Any], ) -> Eval:
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
wren-ai-service/eval/pipelines.py
(7 hunks)wren-ai-service/eval/prediction.py
(4 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
wren-ai-service/eval/prediction.py
25-25: src.globals.create_service_container
imported but unused
Remove unused import: src.globals.create_service_container
(F401)
145-145: Undefined name provider
(F821)
wren-ai-service/eval/pipelines.py
10-10: orjson
imported but unused
Remove unused import: orjson
(F401)
15-15: src.config.settings
imported but unused
Remove unused import: src.config.settings
(F401)
16-16: src.providers.generate_components
imported but unused
Remove unused import: src.providers.generate_components
(F401)
24-24: src.web.v1.services.ask.AskResultResponse
imported but unused
Remove unused import: src.web.v1.services.ask.AskResultResponse
(F401)
🔇 Additional comments (5)
wren-ai-service/eval/prediction.py (2)
53-56
: LGTM! Configuration management improvement.Good refactor to use centralized settings with fallback values.
17-27
:⚠️ Potential issueFix missing provider import causing runtime error.
The code attempts to use
provider.init_providers
on line 145, but theprovider
module is no longer imported.Add the following import:
+from src.providers import provider
Likely invalid or redundant comment.
🧰 Tools
🪛 Ruff (0.8.2)
25-25:
src.globals.create_service_container
imported but unusedRemove unused import:
src.globals.create_service_container
(F401)
wren-ai-service/eval/pipelines.py (3)
355-363
: 🛠️ Refactor suggestionAdd protection against circular references and unhashable keys.
The dictionary to string conversion is vulnerable to:
- Infinite recursion from circular references
- TypeError from unhashable dictionary keys
Add these safety mechanisms:
- def dict_to_string(self, d: dict) -> str: + def dict_to_string(self, d: dict, seen=None) -> str: + if seen is None: + seen = set() + if not isinstance(d, dict): return str(d) + # Check for circular references + d_id = id(d) + if d_id in seen: + return "{...}" # Indicate circular reference + seen.add(d_id) result = "{" - for key, value in d.items(): - result += f"'{key}': {self.dict_to_string(value)}, " + try: + for key, value in d.items(): + result += f"'{key}': {self.dict_to_string(value, seen)}, " + except TypeError as e: + return f"{{Error: {str(e)}}}" result = result.rstrip(", ") + "}" + seen.remove(d_id) return resultLikely invalid or redundant comment.
364-388
: 🛠️ Refactor suggestionAdd validation for required components.
The constructor should validate the presence of all required components before initialization to prevent runtime errors.
Add component validation:
def __init__( self, meta: dict, mdl: dict, service_metadata, pipe_components, ): super().__init__(meta, 3) + required_components = { + 'db_schema_indexing', + 'historical_question_indexing', + 'table_description_indexing', + 'intent_classification', + 'data_assistance', + 'db_schema_retrieval', + 'historical_question_retrieval', + 'sql_generation', + 'sql_correction' + } + + missing = required_components - set(pipe_components.keys()) + if missing: + raise ValueError(f"Missing required components: {missing}") + self.service_metadata = service_metadata self.pipe_components = pipe_components self.project_id = str(uuid.uuid4().int >> 65)Likely invalid or redundant comment.
398-445
:⚠️ Potential issueAdd timeout and error handling to the polling loop.
The current polling implementation lacks several critical safety mechanisms:
- No timeout to prevent infinite loops
- No delay between retries
- No error handling for failed status
Add these safety mechanisms:
async def _process(self, prediction: dict, **_) -> dict: + MAX_RETRIES = 10 + RETRY_DELAY = 1 # seconds + + try: await self.indexing_service_var.prepare_semantics( SemanticsPreparationRequest( mdl=self.mdl_str_var, mdl_hash=self.mdl_hash, project_id=self.project_id ), service_metadata=self.service_metadata, ) ask_request = AskRequest( query=prediction["input"], mdl_hash=self.mdl_hash, project_id=self.project_id, ) ask_request.query_id = str(uuid.uuid4().int >> 65) await self.ask_service_var.ask(ask_request, service_metadata=self.service_metadata) ask_result_response = self.ask_service_var.get_ask_result( AskResultRequest( query_id=ask_request.query_id, ) ) + retries = 0 while ( ask_result_response.status != "finished" and ask_result_response.status != "failed" + and retries < MAX_RETRIES ): + await asyncio.sleep(RETRY_DELAY) + retries += 1 ask_result_response = self.ask_service_var.get_ask_result( AskResultRequest( query_id=ask_request.query_id, ) ) + if retries >= MAX_RETRIES: + raise TimeoutError(f"Ask service request timed out after {MAX_RETRIES} retries") + + if ask_result_response.status == "failed": + raise RuntimeError(f"Ask service request failed: {ask_result_response.error}") + + if not ask_result_response.response: + raise ValueError("Empty response from ask service") + prediction["actual_output"] = ask_result_response.response[0].sql + except Exception as e: + prediction["error"] = str(e) + prediction["actual_output"] = None return predictionLikely invalid or redundant comment.
pipe_components = generate_components(settings.components) | ||
meta = generate_meta( | ||
path=path, | ||
dataset=dataset, | ||
pipe=pipe_name, | ||
**providers, | ||
**pipe_components["db_schema_retrieval"], | ||
) | ||
|
||
service_metadata = create_service_metadata(pipe_components) | ||
pipe = pipelines.init( | ||
pipe_name, | ||
meta, | ||
mdl=dataset["mdl"], | ||
providers=providers, | ||
service_metadata=service_metadata, | ||
pipe_components=pipe_components, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Add error handling for component generation.
The component generation and service metadata creation could fail silently. Consider adding error handling:
- pipe_components = generate_components(settings.components)
- meta = generate_meta(
- path=path,
- dataset=dataset,
- pipe=pipe_name,
- **pipe_components["db_schema_retrieval"],
- )
- service_metadata = create_service_metadata(pipe_components)
+ try:
+ pipe_components = generate_components(settings.components)
+ if not pipe_components.get("db_schema_retrieval"):
+ raise ValueError("Required component 'db_schema_retrieval' not found")
+
+ meta = generate_meta(
+ path=path,
+ dataset=dataset,
+ pipe=pipe_name,
+ **pipe_components["db_schema_retrieval"],
+ )
+ service_metadata = create_service_metadata(pipe_components)
+ except Exception as e:
+ raise RuntimeError(f"Failed to initialize components: {str(e)}") from e
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
pipe_components = generate_components(settings.components) | |
meta = generate_meta( | |
path=path, | |
dataset=dataset, | |
pipe=pipe_name, | |
**providers, | |
**pipe_components["db_schema_retrieval"], | |
) | |
service_metadata = create_service_metadata(pipe_components) | |
pipe = pipelines.init( | |
pipe_name, | |
meta, | |
mdl=dataset["mdl"], | |
providers=providers, | |
service_metadata=service_metadata, | |
pipe_components=pipe_components, | |
try: | |
pipe_components = generate_components(settings.components) | |
if not pipe_components.get("db_schema_retrieval"): | |
raise ValueError("Required component 'db_schema_retrieval' not found") | |
meta = generate_meta( | |
path=path, | |
dataset=dataset, | |
pipe=pipe_name, | |
**pipe_components["db_schema_retrieval"], | |
) | |
service_metadata = create_service_metadata(pipe_components) | |
except Exception as e: | |
raise RuntimeError(f"Failed to initialize components: {str(e)}") from e | |
pipe = pipelines.init( | |
pipe_name, | |
meta, | |
mdl=dataset["mdl"], | |
service_metadata=service_metadata, | |
pipe_components=pipe_components, |
After the code is refactored, the evaluation will not work.
This PR is trying to make the evaluation work again.
But now it reports the error:
looks like engine doesn't load spider test data. I follow the instruction of https://github.com/Canner/WrenAI/tree/main/wren-ai-service/eval#eval-dataset-preparationif-using-spider-10-dataset put the data to
wren-ai-service/tools/dev/spider1.0
ISummary by CodeRabbit
Summary by CodeRabbit
New Features
Refactor
Bug Fixes
Documentation