diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 56c5d2d64..3e7c0f3c4 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -1,6 +1,6 @@ name: Python package -on: [push] +on: [push, pull_request] jobs: build: diff --git a/tests/services/start_process_service_http_0.py b/tests/services/start_process_service_http_0.py new file mode 100644 index 000000000..4ae0fe091 --- /dev/null +++ b/tests/services/start_process_service_http_0.py @@ -0,0 +1,91 @@ +import asyncio +import os +import signal +from typing import Any # noqa + +from aiohttp import web + +import tomodachi +from tomodachi.transport.http import http + + +@tomodachi.service +class HttpService(tomodachi.Service): + name = "test_http" + options = {"http": {"port": 53250, "access_log": True, "real_ip_from": "127.0.0.1"}} + uuid = None + closer: asyncio.Future = asyncio.Future() + function_order = [] + + @http("GET", r"/get-uuid/?") + async def get_uuid(self, request: web.Request) -> str: + return self.uuid + + async def _start_service(self) -> None: + self.function_order.append("_start_service") + + async def _started_service(self) -> None: + self.function_order.append("_started_service") + + async def _async() -> None: + async def sleep_and_kill() -> None: + await asyncio.sleep(4.0) + if not self.closer.done(): + self.closer.set_result(None) + + task = asyncio.ensure_future(sleep_and_kill()) + await self.closer + if not task.done(): + task.cancel() + os.kill(os.getpid(), signal.SIGINT) + + asyncio.ensure_future(_async()) + + def stop_service(self) -> None: + self.function_order.append("stop_service") + if not self.closer.done(): + self.closer.set_result(None) + + async def _stop_service(self) -> None: + self.function_order.append("_stop_service") + + +@tomodachi.service +class HttpService2(tomodachi.Service): + name = "test_http2" + options = {"http": {"port": 53250, "access_log": True, "real_ip_from": "127.0.0.1"}} + uuid = None + closer: asyncio.Future = asyncio.Future() + function_order = [] + + @http("GET", r"/get-uuid/?") + async def get_uuid(self, request: web.Request) -> str: + return self.uuid + + async def _start_service(self) -> None: + self.function_order.append("_start_service") + + async def _started_service(self) -> None: + self.function_order.append("_started_service") + + async def _async() -> None: + async def sleep_and_kill() -> None: + await asyncio.sleep(5.0) + if not self.closer.done(): + self.closer.set_result(None) + + task = asyncio.ensure_future(sleep_and_kill()) + await self.closer + if not task.done(): + task.cancel() + os.kill(os.getpid(), signal.SIGINT) + + asyncio.ensure_future(_async()) + + def stop_service(self) -> None: + self.function_order.append("stop_service") + if not self.closer.done(): + self.closer.set_result(None) + + async def _stop_service(self) -> None: + self.function_order.append("_stop_service") diff --git a/tests/test_http_service.py b/tests/test_http_service.py index ae485c6c4..137ee01af 100644 --- a/tests/test_http_service.py +++ b/tests/test_http_service.py @@ -3,6 +3,7 @@ import mimetypes import os import pathlib +import platform from typing import Any import aiohttp @@ -28,6 +29,10 @@ def test_start_http_service(monkeypatch: Any, capsys: Any, loop: Any) -> None: loop.run_until_complete(future) +@pytest.mark.skipif( + platform.system() == "Linux", + reason="SO_REUSEPORT is automatically enable on Linux", +) def test_conflicting_port_http_service(monkeypatch: Any, capsys: Any, loop: Any) -> None: services, future = start_service("tests/services/http_service_same_port.py", monkeypatch) @@ -343,7 +348,7 @@ def test_access_log(monkeypatch: Any, loop: Any) -> None: assert os.path.exists(log_path) is True with open(log_path) as file: content = file.read() - assert content == "Listening [http] on http://127.0.0.1:{}/\n".format(port) + assert "Listening [http] on http://127.0.0.1:{}/\n".format(port) in content async def _async(loop: Any) -> None: async with aiohttp.ClientSession(loop=loop) as client: @@ -397,7 +402,7 @@ async def _async(loop: Any) -> None: with open(log_path) as file: content = file.read() - assert content == "Listening [http] on http://127.0.0.1:{}/\n".format(port) + assert "Listening [http] on http://127.0.0.1:{}/\n".format(port) in content loop.run_until_complete(_async(loop)) instance.stop_service() diff --git a/tests/test_start_2_process_http_0.py b/tests/test_start_2_process_http_0.py new file mode 100644 index 000000000..9cb04477d --- /dev/null +++ b/tests/test_start_2_process_http_0.py @@ -0,0 +1,49 @@ +import asyncio +import platform +from typing import Any + +import aiohttp +import pytest + +from run_test_service_helper import start_service + + +@pytest.mark.skipif( + platform.system() != "Linux", + reason="SO_REUSEPORT can only be enable on Linux", +) +def test_start_2_process_http_reuse_port_request(monkeypatch: Any, capsys: Any, loop: Any) -> None: + func, future = start_service("tests/services/start_process_service_http_0.py", monkeypatch, wait=False) + + port = 53250 + + async def _async(loop: Any) -> None: + await asyncio.sleep(1) + async with aiohttp.ClientSession(loop=loop) as client: + services_uuid = set() + for ti in range(4): + response = await client.get("http://127.0.0.1:{}/get-uuid".format(port)) + data = await response.read() + assert len(data) > 0 + services_uuid.add(str(data)) + assert len(services_uuid) == 2 + + loop.run_until_complete(_async(loop)) + loop.run_until_complete(future) + + services = func() + assert services is not None + assert len(services) == 2 + instance1 = services.get("test_http") + assert instance1 is not None + assert instance1.uuid is not None + instance2 = services.get("test_http2") + assert instance2 is not None + assert instance2.uuid is not None + + assert instance1.uuid != instance2.uuid + assert instance1.function_order == ["_start_service", "_started_service", "_stop_service"] + assert instance2.function_order == ["_start_service", "_started_service", "_stop_service"] + + instance1.stop_service() + instance2.stop_service() diff --git a/tomodachi/transport/http.py b/tomodachi/transport/http.py index 80ac0b215..372fa4fbc 100644 --- a/tomodachi/transport/http.py +++ b/tomodachi/transport/http.py @@ -5,6 +5,7 @@ import logging import os import pathlib +import platform import re import time import uuid @@ -1075,7 +1076,21 @@ async def func() -> Union[web.Response, web.FileResponse]: keepalive_timeout=keepalive_timeout, tcp_keepalive=tcp_keepalive, ) - server_task = loop.create_server(web_server, host, port) # type: ignore + if platform.system() == "Linux": + reuse_port = True + if port == 0: + http_logger.warning( + "listen on random port (0) with SO_REUSEPORT is dangerous." + " Please double check your intent." + ) + else: + http_logger.warning( + "SO_REUSEPORT set to True automatically for Linux platform." + " Different service must not use the same port ({})".format(port) + ) + else: + reuse_port = False + server_task = loop.create_server(web_server, host, port, reuse_port=reuse_port) # type: ignore server = await server_task # type: ignore except OSError as e: context["_http_accept_new_requests"] = False