Make the async task and sync task api actually return HTTP response codes from the user container (#676)

* comment

* custom status code in echo server

* catch InvalidRequestException from sync/streaming inference gateways

* mark where we'd probably add the status_code

* status code to dto

* add to sync dto also

* fix tests

* task queue gateway returns status code if possible

* forwarder stuff

* black

* try fixing integration tests

* try fixing unit tests

* missed spots

* test case for cov

* oops

* fix

* ...

* ugh

* eh just remove status code from result manually

* revert integration test changes
seanshi-scale authored Jan 21, 2025
1 parent 69875fc commit f3e466c
Showing 20 changed files with 161 additions and 22 deletions.
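In short: sync, streaming, and async task responses gain an optional status_code field carrying the HTTP status that the user container itself returned, so a FAILURE no longer hides whether the container answered 4xx or 5xx. A rough sketch of the new body shapes (Python dict literals; values illustrative, field sets taken from the DTO and test changes below):

# Sync/streaming success now carries the upstream HTTP status:
sync_success = {
    "status": "SUCCESS",
    "result": {"result": {"y": 1}},
    "traceback": None,
    "status_code": 200,
}

# A failure surfaces whatever the container returned (429 here is illustrative):
sync_failure = {
    "status": "FAILURE",
    "result": None,
    "traceback": "<error body from the user container>",
    "status_code": 429,
}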
2 changes: 1 addition & 1 deletion integration_tests/test_endpoints.py
@@ -232,7 +232,7 @@ def test_sync_streaming_model_endpoint(capsys):
for response in task_responses:
assert (
response.strip()
- == 'data: {"status":"SUCCESS","result":{"result":{"y":1}},"traceback":null}'
+ == 'data: {"status":"SUCCESS","result":{"result":{"y":1}},"traceback":null,"status_code":200}'
)
finally:
delete_model_endpoint(create_endpoint_request["name"], user)
16 changes: 14 additions & 2 deletions model-engine/model_engine_server/api/tasks_v1.py
@@ -116,7 +116,7 @@ async def create_sync_inference_task(
)
except UpstreamServiceError as exc:
return SyncEndpointPredictV1Response(
- status=TaskStatus.FAILURE, traceback=exc.content.decode()
+ status=TaskStatus.FAILURE, traceback=exc.content.decode(), status_code=exc.status_code
)
except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
raise HTTPException(
@@ -133,6 +133,11 @@ async def create_sync_inference_task(
status_code=408,
detail="Request timed out.",
) from exc
+ except InvalidRequestException as exc:
+     raise HTTPException(
+         status_code=400,
+         detail=f"Invalid request: {str(exc)}",
+     ) from exc


@inference_task_router_v1.post("/streaming-tasks")
@@ -164,7 +169,9 @@ async def event_generator():
iter(
(
SyncEndpointPredictV1Response(
- status=TaskStatus.FAILURE, traceback=exc.content.decode()
+ status=TaskStatus.FAILURE,
+ traceback=exc.content.decode(),
+ status_code=exc.status_code,
).json(),
)
)
@@ -179,3 +186,8 @@ async def event_generator():
status_code=400,
detail=f"Unsupported inference type: {str(exc)}",
) from exc
+ except InvalidRequestException as exc:
+     raise HTTPException(
+         status_code=400,
+         detail=f"Invalid request: {str(exc)}",
+     ) from exc
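For callers, the upshot is that a FAILURE response now reports what the container actually returned, and malformed requests surface as a 400 instead of an unhandled error. A minimal client sketch; the host, route, and payload here are hypothetical, not taken from this diff:

import requests

# Hypothetical host and sync-task route; substitute your deployment's values.
resp = requests.post(
    "https://model-engine.example.com/v1/sync-tasks?model_endpoint_id=abc123",
    json={"args": {"y": 1}},
)
body = resp.json()
if body["status"] == "FAILURE":
    # body["status_code"] is the user container's HTTP status, which can differ
    # from resp.status_code (the status of the model-engine API call itself).
    print("container returned", body.get("status_code"), body.get("traceback"))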
2 changes: 2 additions & 0 deletions model-engine/model_engine_server/common/dtos/tasks.py
@@ -34,12 +34,14 @@ class GetAsyncTaskV1Response(BaseModel):
status: TaskStatus
result: Optional[ResponseSchema] = None
traceback: Optional[str] = None
+ status_code: Optional[int] = None


class SyncEndpointPredictV1Response(BaseModel):
status: TaskStatus
result: Optional[Any] = None
traceback: Optional[str] = None
+ status_code: Optional[int] = None


class EndpointPredictV1Request(BaseModel):
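Because status_code defaults to None on both DTOs, payloads produced before this change still validate. A self-contained sketch mirroring the updated sync DTO (simplified stand-in classes; the real ones live in model_engine_server/common/dtos/tasks.py):

from enum import Enum
from typing import Any, Optional

from pydantic import BaseModel


class TaskStatus(str, Enum):  # stand-in for the real TaskStatus enum
    SUCCESS = "SUCCESS"
    FAILURE = "FAILURE"


class SyncEndpointPredictV1Response(BaseModel):
    status: TaskStatus
    result: Optional[Any] = None
    traceback: Optional[str] = None
    status_code: Optional[int] = None  # new field; None keeps old payloads valid


# An old-style response without status_code still constructs and serializes:
old = SyncEndpointPredictV1Response(status=TaskStatus.SUCCESS, result={"y": 1})
assert old.status_code is None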
@@ -20,3 +20,4 @@ forwarder:
model_engine_unwrap: false
serialize_results_as_string: false
wrap_response: false
+ forward_http_status_in_body: true
@@ -18,3 +18,4 @@ forwarder:
batch_route: null
model_engine_unwrap: true
serialize_results_as_string: true
+ forward_http_status_in_body: true
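Both forwarder configs opt in through the new flag, which flows into the LoadForwarder dataclass changed further down (its default is False, so configs that don't set it keep the old body shape). A hedged construction sketch; the import path is assumed from the repo layout, and the remaining LoadForwarder fields are left at their defaults:

# Import path assumed; see the forwarder diff below for the dataclass fields.
from model_engine_server.inference.forwarding.forwarder import LoadForwarder

loader = LoadForwarder(
    serialize_results_as_string=True,
    wrap_response=True,
    forward_http_status=False,          # HTTP status of the response itself, unchanged
    forward_http_status_in_body=True,   # new: embed the upstream status in the body
)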
@@ -141,6 +141,9 @@ def exec_func(payload, arrival_timestamp, *ignored_args, **ignored_kwargs):
logger.warning(f"Ignoring {len(ignored_kwargs)} keyword arguments: {ignored_kwargs=}")
try:
monitoring_metrics_gateway.emit_async_task_received_metric(queue_name)
+ # Don't fail the celery task even if there's a status code
+ # (otherwise we can't really control what gets put in the result attribute)
+ # in the task (https://docs.celeryq.dev/en/stable/reference/celery.result.html#celery.result.AsyncResult.status)
result = forwarder(payload)
request_duration = datetime.now() - arrival_timestamp
if request_duration > timedelta(seconds=DEFAULT_TASK_VISIBILITY_SECONDS):
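The comment added in this hunk is the crux of the async design: the celery task never fails just because the container returned an error status, because a failed task gives no control over what lands in its result attribute. Roughly, the stored result looks like this (values illustrative):

# Even when the container answered 500, celery's AsyncResult.status stays
# "SUCCESS"; the upstream failure travels inside the stored result instead,
# and get_task() (changed below) translates it back out for the client.
stored_result = {"result": {"error": "boom"}, "status_code": 500}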
@@ -27,12 +27,21 @@ async def predict(request: Request):
print("Received request", dictionary, flush=True)
if "delay" in dictionary:
await asyncio.sleep(dictionary["delay"])
if "status_code" in dictionary:
return JSONResponse(content=dictionary, status_code=dictionary["status_code"])
return dictionary


@app.post("/predict500")
async def predict500(request: Request):
- response = JSONResponse(content=await request.json(), status_code=500)
+ dictionary = await request.json()
+ if "delay" in dictionary:
+     await asyncio.sleep(dictionary["delay"])
+ if "status_code" in dictionary:
+     status_code = dictionary["status_code"]
+ else:
+     status_code = 500
+ response = JSONResponse(content=dictionary, status_code=status_code)
return response


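The upgraded echo server makes the new path easy to exercise end to end: put a status_code in the request body and the container answers with that HTTP status. A usage sketch, assuming the echo server is running locally on port 5005 (host and port are illustrative):

import requests

# /predict echoes the body back and honors an embedded status_code:
r = requests.post("http://localhost:5005/predict", json={"y": 1, "status_code": 429})
assert r.status_code == 429
assert r.json() == {"y": 1, "status_code": 429}

# /predict500 still defaults to 500 when no status_code is supplied:
r = requests.post("http://localhost:5005/predict500", json={"y": 1})
assert r.status_code == 500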
@@ -4,7 +4,7 @@
import time
from dataclasses import dataclass
from pathlib import Path
- from typing import Any, AsyncGenerator, Iterable, List, Optional, Sequence, Tuple
+ from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Sequence, Tuple

import aiohttp
import orjson
@@ -101,13 +101,24 @@ def unwrap_json_payload(self, json_payload: Any) -> Tuple[Any, bool]:
return json_payload, using_serialize_results_as_string

@staticmethod
- def get_response_payload(using_serialize_results_as_string: bool, response: Any):
+ def get_response_payload(
+     using_serialize_results_as_string: bool,
+     forward_http_status_in_body: bool,
+     response: Any,
+     status_code: int,
+ ) -> Any:
# Model Engine expects a JSON object with a "result" key.

+ response_payload: Dict[str, Any] = {}
if using_serialize_results_as_string:
response_as_string: str = json.dumps(response)
return {"result": response_as_string}
response_payload["result"] = response_as_string
else:
response_payload["result"] = response

return {"result": response}
if forward_http_status_in_body:
response_payload["status_code"] = status_code
return response_payload

@staticmethod
def get_response_payload_stream(using_serialize_results_as_string: bool, response: str):
@@ -148,7 +159,12 @@ class Forwarder(ModelEngineSerializationMixin):
model_engine_unwrap: bool
serialize_results_as_string: bool
wrap_response: bool
- forward_http_status: bool
+ # See celery_task_queue_gateway.py for why we should keep wrap_response as True
+ # for async. tl;dr is we need to convey both the result as well as status code.
+ forward_http_status: bool  # Forwards http status in JSONResponse
+ # Forwards http status in the response body. Only used if wrap_response is True
+ # We do this to avoid having to put this data in any sync response and only do it for async responses
+ forward_http_status_in_body: bool
post_inference_hooks_handler: Optional[PostInferenceHooksHandler] = None

async def forward(self, json_payload: Any) -> Any:
@@ -191,7 +207,12 @@ async def forward(self, json_payload: Any) -> Any:
)

if self.wrap_response:
- response = self.get_response_payload(using_serialize_results_as_string, response)
+ response = self.get_response_payload(
+     using_serialize_results_as_string,
+     self.forward_http_status_in_body,
+     response,
+     response_raw.status,
+ )

if self.forward_http_status:
return JSONResponse(content=response, status_code=response_raw.status)
@@ -233,7 +254,12 @@ def __call__(self, json_payload: Any) -> Any:
)

if self.wrap_response:
- response = self.get_response_payload(using_serialize_results_as_string, response)
+ response = self.get_response_payload(
+     using_serialize_results_as_string,
+     self.forward_http_status_in_body,
+     response,
+     response_raw.status_code,
+ )

if self.forward_http_status:
return JSONResponse(content=response, status_code=response_raw.status_code)
@@ -263,6 +289,7 @@ class LoadForwarder:
serialize_results_as_string: bool = True
wrap_response: bool = True
forward_http_status: bool = False
+ forward_http_status_in_body: bool = False

def load(self, resources: Optional[Path], cache: Any) -> Forwarder:
if self.use_grpc:
@@ -370,6 +397,7 @@ def endpoint(route: str) -> str:
post_inference_hooks_handler=handler,
wrap_response=self.wrap_response,
forward_http_status=self.forward_http_status,
+ forward_http_status_in_body=self.forward_http_status_in_body,
)


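Concretely, get_response_payload is a static method, so both output shapes can be checked without standing up a full forwarder (sketch assumes Forwarder is imported from this module; json.dumps output shown for the serialized case):

import json  # get_response_payload uses json.dumps for the serialized case

# Old shape: flag off, the body is just the wrapped result.
assert Forwarder.get_response_payload(
    using_serialize_results_as_string=False,
    forward_http_status_in_body=False,
    response={"y": 1},
    status_code=200,
) == {"result": {"y": 1}}

# New shape: flag on, the upstream status rides along in the body.
assert Forwarder.get_response_payload(
    using_serialize_results_as_string=True,
    forward_http_status_in_body=True,
    response={"y": 1},
    status_code=429,
) == {"result": json.dumps({"y": 1}), "status_code": 429}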
@@ -68,6 +68,7 @@ def send_task(
kwargs: Optional[Dict[str, Any]] = None,
expires: Optional[int] = None,
) -> CreateAsyncTaskV1Response:
+ # Used for both endpoint infra creation and async tasks
celery_dest = self._get_celery_dest()

try:
@@ -84,6 +85,7 @@
return CreateAsyncTaskV1Response(task_id=res.id)

def get_task(self, task_id: str) -> GetAsyncTaskV1Response:
+ # Only used for async tasks
celery_dest = self._get_celery_dest()
res = celery_dest.AsyncResult(task_id)
response_state = res.state
@@ -92,15 +94,27 @@ def get_task(self, task_id: str) -> GetAsyncTaskV1Response:
# result_dict = (
# response_result if type(response_result) is dict else {"result": response_result}
# )
+ status_code = None
+ result = res.result
+ if type(result) is dict and "status_code" in result:
+     # Filter out status code from result if it was added by the forwarder
+     # This is admittedly kinda hacky and would technically introduce an edge case
+     # if we ever decide not to have async tasks wrap response.
+     status_code = result["status_code"]
+     del result["status_code"]
return GetAsyncTaskV1Response(
- task_id=task_id, status=TaskStatus.SUCCESS, result=res.result
+ task_id=task_id,
+ status=TaskStatus.SUCCESS,
+ result=result,
+ status_code=status_code,
)

elif response_state == "FAILURE":
return GetAsyncTaskV1Response(
task_id=task_id,
status=TaskStatus.FAILURE,
traceback=res.traceback,
+ status_code=None,  # probably
)
elif response_state == "RETRY":
# Backwards compatibility, otherwise we'd need to add "RETRY" to the clients
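The filtering in get_task keeps the client-visible result identical to the pre-change shape while lifting the code to the top level of the response; the del in the diff is equivalent to a dict pop, illustrated on a sample stored result:

# What the forwarder stores when wrap_response and forward_http_status_in_body
# are both true (values illustrative):
raw = {"result": {"y": 1}, "status_code": 200}

status_code = raw.pop("status_code", None)  # same effect as the del in get_task
assert raw == {"result": {"y": 1}}          # result payload unchanged for clients
assert status_code == 200                   # surfaced as GetAsyncTaskV1Response.status_code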
@@ -235,7 +235,9 @@ async def streaming_predict(
endpoint_name=endpoint_name or topic,
)
async for item in response:
- yield SyncEndpointPredictV1Response(status=TaskStatus.SUCCESS, result=item)
+ yield SyncEndpointPredictV1Response(
+     status=TaskStatus.SUCCESS, result=item, status_code=200
+ )
except UpstreamServiceError as exc:
logger.error(f"Service error on streaming task: {exc.content!r}")

@@ -258,4 +260,5 @@ async def streaming_predict(
yield SyncEndpointPredictV1Response(
status=TaskStatus.FAILURE,
traceback=result_traceback,
+ status_code=exc.status_code,
)
@@ -238,12 +238,17 @@ async def predict(
return SyncEndpointPredictV1Response(
status=TaskStatus.FAILURE,
traceback=result_traceback,
+ status_code=exc.status_code,
)

except Exception as e:
logger.error(f"Failed to parse error: {e}")
return SyncEndpointPredictV1Response(
- status=TaskStatus.FAILURE, traceback=exc.content.decode()
+ status=TaskStatus.FAILURE,
+ traceback=exc.content.decode(),
+ status_code=exc.status_code,
)

- return SyncEndpointPredictV1Response(status=TaskStatus.SUCCESS, result=response)
+ return SyncEndpointPredictV1Response(
+     status=TaskStatus.SUCCESS, result=response, status_code=200
+ )
1 change: 1 addition & 0 deletions model-engine/tests/unit/api/test_llms.py
@@ -116,6 +116,7 @@ def test_completion_sync_success(
}"""
},
traceback=None,
+ status_code=200,
),
)
response_1 = client.post(
3 changes: 2 additions & 1 deletion model-engine/tests/unit/api/test_tasks.py
@@ -410,7 +410,8 @@ async def test_create_streaming_task_success(
count = 0
async for message in response.aiter_bytes():
assert (
- message == b'data: {"status":"SUCCESS","result":null,"traceback":null}\r\n\r\n'
+ message
+ == b'data: {"status":"SUCCESS","result":null,"traceback":null,"status_code":200}\r\n\r\n'
)
count += 1
assert count == 1
10 changes: 6 additions & 4 deletions model-engine/tests/unit/conftest.py
@@ -1028,18 +1028,17 @@ def get_task_args(self, task_id: str):

def get_task(self, task_id: str) -> GetAsyncTaskV1Response:
result = None
+ status_code = None
if task_id in self.queue:
status = TaskStatus.PENDING
elif task_id in self.completed:
status = TaskStatus.SUCCESS
result = self.completed[task_id]
+ status_code = 200
else:
status = TaskStatus.UNDEFINED
return GetAsyncTaskV1Response(
- task_id=task_id,
- status=status,
- result=result,
- traceback=None,
+ task_id=task_id, status=status, result=result, traceback=None, status_code=status_code
)

def clear_queue(self, queue_name: str) -> bool:
@@ -1537,6 +1536,7 @@ def __init__(self):
status=TaskStatus.SUCCESS,
result=None,
traceback=None,
+ status_code=200,
)
]

@@ -1561,6 +1561,7 @@ def __init__(self, fake_sync_inference_content=None):
status=TaskStatus.SUCCESS,
result=None,
traceback=None,
+ status_code=200,
)
else:
self.response = fake_sync_inference_content
@@ -1662,6 +1663,7 @@ def get_task(self, task_id: str) -> GetAsyncTaskV1Response:
status=TaskStatus.SUCCESS,
result=None,
traceback=None,
+ status_code=200,
)

def get_last_request(self):