diff --git a/nextline/spawned/events.py b/nextline/events.py similarity index 78% rename from nextline/spawned/events.py rename to nextline/events.py index 43680176..86d90d8d 100644 --- a/nextline/spawned/events.py +++ b/nextline/events.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import Optional -from nextline.types import PromptNo, TaskNo, ThreadNo, TraceNo +from nextline.types import PromptNo, RunNo, TaskNo, ThreadNo, TraceNo @dataclass @@ -13,6 +13,7 @@ class Event: @dataclass class OnStartTrace(Event): started_at: datetime.datetime + run_no: RunNo trace_no: TraceNo thread_no: ThreadNo task_no: Optional[TaskNo] @@ -21,48 +22,59 @@ class OnStartTrace(Event): @dataclass class OnEndTrace(Event): ended_at: datetime.datetime + run_no: RunNo trace_no: TraceNo @dataclass class OnStartTraceCall(Event): started_at: datetime.datetime + run_no: RunNo trace_no: TraceNo file_name: str line_no: int frame_object_id: int - call_event: str + event: str @dataclass class OnEndTraceCall(Event): ended_at: datetime.datetime + run_no: RunNo trace_no: TraceNo @dataclass class OnStartCmdloop(Event): started_at: datetime.datetime + run_no: RunNo trace_no: TraceNo @dataclass class OnEndCmdloop(Event): ended_at: datetime.datetime + run_no: RunNo trace_no: TraceNo @dataclass class OnStartPrompt(Event): started_at: datetime.datetime + run_no: RunNo trace_no: TraceNo prompt_no: PromptNo prompt_text: str + file_name: str + line_no: int + frame_object_id: int + event: str @dataclass class OnEndPrompt(Event): ended_at: datetime.datetime + run_no: RunNo trace_no: TraceNo prompt_no: PromptNo command: str @@ -71,5 +83,6 @@ class OnEndPrompt(Event): @dataclass class OnWriteStdout(Event): written_at: datetime.datetime + run_no: RunNo trace_no: TraceNo text: str diff --git a/nextline/plugin/plugins/__init__.py b/nextline/plugin/plugins/__init__.py index 56de127d..539bb0cb 100644 --- a/nextline/plugin/plugins/__init__.py +++ b/nextline/plugin/plugins/__init__.py @@ -14,7 +14,7 @@ TraceInfoRegistrar, TraceNumbersRegistrar, ) -from .session import CommandSender, Result, RunSession, Signal +from .session import CommandSender, Result, RunSession, Signal, OnEvent def register(hook: PluginManager) -> None: @@ -31,4 +31,5 @@ def register(hook: PluginManager) -> None: hook.register(Result) hook.register(Signal) hook.register(CommandSender) + hook.register(OnEvent) hook.register(RunSession) diff --git a/nextline/plugin/plugins/registrars/prompt_info.py b/nextline/plugin/plugins/registrars/prompt_info.py index 9e63f03a..54595327 100644 --- a/nextline/plugin/plugins/registrars/prompt_info.py +++ b/nextline/plugin/plugins/registrars/prompt_info.py @@ -2,8 +2,7 @@ import dataclasses from logging import getLogger -from nextline.plugin.spec import Context, hookimpl -from nextline.spawned import ( +from nextline.events import ( OnEndPrompt, OnEndTrace, OnEndTraceCall, @@ -11,6 +10,7 @@ OnStartTrace, OnStartTraceCall, ) +from nextline.plugin.spec import Context, hookimpl from nextline.types import PromptInfo, PromptNo, TraceNo @@ -95,7 +95,7 @@ async def on_end_trace_call(self, context: Context, event: OnEndTraceCall) -> No trace_no=trace_no, prompt_no=PromptNo(-1), open=False, - event=trace_call.call_event, + event=trace_call.event, file_name=trace_call.file_name, line_no=trace_call.line_no, trace_call_end=True, @@ -118,7 +118,7 @@ async def on_start_prompt(self, context: Context, event: OnStartPrompt) -> None: trace_no=trace_no, prompt_no=prompt_no, open=True, - event=trace_call.call_event, + event=trace_call.event, 
file_name=trace_call.file_name, line_no=trace_call.line_no, stdout=event.prompt_text, diff --git a/nextline/plugin/plugins/registrars/prompt_notice.py b/nextline/plugin/plugins/registrars/prompt_notice.py index 3c0e2086..ad1a3bb7 100644 --- a/nextline/plugin/plugins/registrars/prompt_notice.py +++ b/nextline/plugin/plugins/registrars/prompt_notice.py @@ -1,8 +1,8 @@ from logging import getLogger from typing import Optional +from nextline.events import OnEndTraceCall, OnStartPrompt, OnStartTraceCall from nextline.plugin.spec import Context, hookimpl -from nextline.spawned import OnEndTraceCall, OnStartPrompt, OnStartTraceCall from nextline.types import PromptNotice, RunNo, TraceNo @@ -40,7 +40,7 @@ async def on_start_prompt(self, context: Context, event: OnStartPrompt) -> None: trace_no=trace_no, prompt_no=prompt_no, prompt_text=event.prompt_text, - event=trace_call.call_event, + event=trace_call.event, file_name=trace_call.file_name, line_no=trace_call.line_no, ) diff --git a/nextline/plugin/plugins/registrars/run_info.py b/nextline/plugin/plugins/registrars/run_info.py index d1d9e367..20b330d5 100644 --- a/nextline/plugin/plugins/registrars/run_info.py +++ b/nextline/plugin/plugins/registrars/run_info.py @@ -1,54 +1,53 @@ import dataclasses -import datetime +from datetime import timezone from typing import Optional -from nextline import spawned from nextline.plugin.spec import Context, hookimpl from nextline.types import RunInfo -from nextline.utils import ExitedProcess class RunInfoRegistrar: def __init__(self) -> None: - self._script: Optional[str] = None self._run_info: Optional[RunInfo] = None - @hookimpl - async def on_change_script(self, script: str) -> None: - self._script = script - @hookimpl async def on_initialize_run(self, context: Context) -> None: assert context.run_arg + if isinstance(context.run_arg.statement, str): + script = context.run_arg.statement + else: + script = None self._run_info = RunInfo( - run_no=context.run_arg.run_no, state='initialized', script=self._script + run_no=context.run_arg.run_no, state='initialized', script=script ) await context.pubsub.publish('run_info', self._run_info) @hookimpl async def on_start_run(self, context: Context) -> None: assert self._run_info is not None + assert context.running_process + assert context.running_process.process_created_at.tzinfo is timezone.utc + started_at = context.running_process.process_created_at.replace(tzinfo=None) self._run_info = dataclasses.replace( - self._run_info, - state='running', - started_at=datetime.datetime.utcnow(), + self._run_info, state='running', started_at=started_at ) await context.pubsub.publish('run_info', self._run_info) @hookimpl - async def on_end_run( - self, context: Context, exited_process: ExitedProcess[spawned.RunResult] - ) -> None: + async def on_end_run(self, context: Context) -> None: assert self._run_info is not None - run_result = exited_process.returned or spawned.RunResult(ret=None, exc=None) + assert context.exited_process + run_result = context.exited_process.returned + assert run_result + assert context.exited_process.process_exited_at.tzinfo is timezone.utc + ended_at = context.exited_process.process_exited_at.replace(tzinfo=None) self._run_info = dataclasses.replace( self._run_info, state='finished', result=run_result.fmt_ret, exception=run_result.fmt_exc, - ended_at=datetime.datetime.utcnow(), + ended_at=ended_at, ) await context.pubsub.publish('run_info', self._run_info) - self._run_info = None diff --git a/nextline/plugin/plugins/registrars/stdout.py 
b/nextline/plugin/plugins/registrars/stdout.py index 8fea4252..795d9379 100644 --- a/nextline/plugin/plugins/registrars/stdout.py +++ b/nextline/plugin/plugins/registrars/stdout.py @@ -1,7 +1,7 @@ from typing import Optional +from nextline.events import OnWriteStdout from nextline.plugin.spec import Context, hookimpl -from nextline.spawned import OnWriteStdout from nextline.types import RunNo, StdoutInfo diff --git a/nextline/plugin/plugins/registrars/trace_info.py b/nextline/plugin/plugins/registrars/trace_info.py index 7fe164ee..bef35188 100644 --- a/nextline/plugin/plugins/registrars/trace_info.py +++ b/nextline/plugin/plugins/registrars/trace_info.py @@ -1,8 +1,8 @@ import dataclasses import datetime +from nextline.events import OnEndTrace, OnStartTrace from nextline.plugin.spec import Context, hookimpl -from nextline.spawned import OnEndTrace, OnStartTrace from nextline.types import TraceInfo, TraceNo diff --git a/nextline/plugin/plugins/registrars/trace_nos.py b/nextline/plugin/plugins/registrars/trace_nos.py index 58d20a8d..893419a0 100644 --- a/nextline/plugin/plugins/registrars/trace_nos.py +++ b/nextline/plugin/plugins/registrars/trace_nos.py @@ -1,5 +1,5 @@ +from nextline.events import OnEndTrace, OnStartTrace from nextline.plugin.spec import Context, hookimpl -from nextline.spawned import OnEndTrace, OnStartTrace from nextline.types import TraceNo diff --git a/nextline/plugin/plugins/session/__init__.py b/nextline/plugin/plugins/session/__init__.py index 5bacd751..4cf135b3 100644 --- a/nextline/plugin/plugins/session/__init__.py +++ b/nextline/plugin/plugins/session/__init__.py @@ -1,8 +1,10 @@ __all__ = [ + 'OnEvent', 'CommandSender', 'Result', 'RunSession', 'Signal', ] +from .monitor import OnEvent from .session import CommandSender, Result, RunSession, Signal diff --git a/nextline/plugin/plugins/session/monitor.py b/nextline/plugin/plugins/session/monitor.py index a5561d4c..2997fb93 100644 --- a/nextline/plugin/plugins/session/monitor.py +++ b/nextline/plugin/plugins/session/monitor.py @@ -1,54 +1,32 @@ -import asyncio -import time -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager from logging import getLogger -from nextline import spawned -from nextline.plugin.spec import Context -from nextline.spawned import QueueOut +from nextline import events +from nextline.plugin.spec import Context, hookimpl -@asynccontextmanager -async def relay_queue(context: Context, queue: QueueOut) -> AsyncIterator[None]: - task = asyncio.create_task(_monitor(context, queue)) - try: - yield - finally: - up_to = 0.05 - start = time.process_time() - while not queue.empty() and time.process_time() - start < up_to: - await asyncio.sleep(0) - await asyncio.to_thread(queue.put, None) # type: ignore - await task - - -async def _monitor(context: Context, queue: QueueOut) -> None: - while (event := await asyncio.to_thread(queue.get)) is not None: - await _on_event(context, event) - - -async def _on_event(context: Context, event: spawned.Event) -> None: - ahook = context.hook.ahook - match event: - case spawned.OnStartTrace(): - await ahook.on_start_trace(context=context, event=event) - case spawned.OnEndTrace(): - await ahook.on_end_trace(context=context, event=event) - case spawned.OnStartTraceCall(): - await ahook.on_start_trace_call(context=context, event=event) - case spawned.OnEndTraceCall(): - await ahook.on_end_trace_call(context=context, event=event) - case spawned.OnStartCmdloop(): - await ahook.on_start_cmdloop(context=context, event=event) - case 
spawned.OnEndCmdloop(): - await ahook.on_end_cmdloop(context=context, event=event) - case spawned.OnStartPrompt(): - await ahook.on_start_prompt(context=context, event=event) - case spawned.OnEndPrompt(): - await ahook.on_end_prompt(context=context, event=event) - case spawned.OnWriteStdout(): - await ahook.on_write_stdout(context=context, event=event) - case _: - logger = getLogger(__name__) - logger.warning(f'Unknown event: {event!r}') +class OnEvent: + @hookimpl + async def on_event_in_process(self, context: Context, event: events.Event) -> None: + ahook = context.hook.ahook + match event: + case events.OnStartTrace(): + await ahook.on_start_trace(context=context, event=event) + case events.OnEndTrace(): + await ahook.on_end_trace(context=context, event=event) + case events.OnStartTraceCall(): + await ahook.on_start_trace_call(context=context, event=event) + case events.OnEndTraceCall(): + await ahook.on_end_trace_call(context=context, event=event) + case events.OnStartCmdloop(): + await ahook.on_start_cmdloop(context=context, event=event) + case events.OnEndCmdloop(): + await ahook.on_end_cmdloop(context=context, event=event) + case events.OnStartPrompt(): + await ahook.on_start_prompt(context=context, event=event) + case events.OnEndPrompt(): + await ahook.on_end_prompt(context=context, event=event) + case events.OnWriteStdout(): + await ahook.on_write_stdout(context=context, event=event) + case _: + logger = getLogger(__name__) + logger.warning(f'Unknown event: {event!r}') diff --git a/nextline/plugin/plugins/session/session.py b/nextline/plugin/plugins/session/session.py index 9690f206..f09177b9 100644 --- a/nextline/plugin/plugins/session/session.py +++ b/nextline/plugin/plugins/session/session.py @@ -1,15 +1,19 @@ +import asyncio +import contextlib +import multiprocessing as mp +import time from collections.abc import AsyncIterator, Callable +from functools import partial from logging import getLogger -from typing import Any, Optional +from typing import Any, Optional, cast import apluggy from tblib import pickling_support +from nextline import spawned from nextline.plugin.spec import Context, hookimpl -from nextline.spawned import Command, RunResult -from nextline.utils import ExitedProcess, RunningProcess - -from .spawn import run_session +from nextline.spawned import Command, QueueIn, QueueOut, RunResult +from nextline.utils import run_in_process pickling_support.install() @@ -18,57 +22,104 @@ class RunSession: @hookimpl @apluggy.asynccontextmanager async def run(self, context: Context) -> AsyncIterator[None]: - ahook = context.hook.ahook - async with run_session(context) as (running, send_command): - await ahook.on_start_run( - context=context, running_process=running, send_command=send_command + assert context.run_arg + context.exited_process = None + mp_context = mp.get_context('spawn') + queue_in = cast(QueueIn, mp_context.Queue()) + queue_out = cast(QueueOut, mp_context.Queue()) + context.send_command = SendCommand(queue_in) + async with relay_events(context, queue_out): + context.running_process = await run_in_process( + func=partial(spawned.main, context.run_arg), + mp_context=mp_context, + initializer=partial(spawned.set_queues, queue_in, queue_out), + collect_logging=True, ) - yield - exited = await running - if exited.raised: - logger = getLogger(__name__) - logger.exception(exited.raised) - self._run_result = exited.returned or RunResult(ret=None, exc=None) - await ahook.on_end_run(context=context, exited_process=exited) + await 
context.hook.ahook.on_start_run(context=context) + try: + yield + finally: + context.exited_process = await context.running_process + if context.exited_process.returned is None: + context.exited_process.returned = RunResult() + context.running_process = None + if context.exited_process.raised: + logger = getLogger(__name__) + logger.exception(context.exited_process.raised) + await context.hook.ahook.on_end_run(context=context) + + +def SendCommand(queue_in: QueueIn) -> Callable[[Command], None]: + def _send_command(command: Command) -> None: + logger = getLogger(__name__) + logger.debug(f'send_pdb_command({command!r}') + queue_in.put(command) + + return _send_command + + +@contextlib.asynccontextmanager +async def relay_events(context: Context, queue: QueueOut) -> AsyncIterator[None]: + '''Call the hook `on_event_in_process()` on events emitted in the spawned process.''' + logger = getLogger(__name__) + + async def _monitor() -> None: + while (event := await asyncio.to_thread(queue.get)) is not None: + logger.debug(f'event: {event!r}') + await context.hook.ahook.on_event_in_process(context=context, event=event) + + task = asyncio.create_task(_monitor()) + try: + yield + finally: + up_to = 0.05 + start = time.process_time() + while not queue.empty(): + await asyncio.sleep(0) + if time.process_time() - start > up_to: + logger.warning(f'Timeout. the queue is not empty: {queue!r}') + break + await asyncio.to_thread(queue.put, None) # type: ignore + await task class Signal: @hookimpl - async def on_start_run(self, running_process: RunningProcess[RunResult]) -> None: - self._running = running_process - - @hookimpl - async def interrupt(self) -> None: - self._running.interrupt() + async def interrupt(self, context: Context) -> None: + assert context.running_process + context.running_process.interrupt() @hookimpl - async def terminate(self) -> None: - self._running.terminate() + async def terminate(self, context: Context) -> None: + assert context.running_process + context.running_process.terminate() @hookimpl - async def kill(self) -> None: - self._running.kill() + async def kill(self, context: Context) -> None: + assert context.running_process + context.running_process.kill() class CommandSender: @hookimpl - async def on_start_run(self, send_command: Callable[[Command], None]) -> None: - self._send_command = send_command - - @hookimpl - async def send_command(self, command: Command) -> None: - self._send_command(command) + async def send_command(self, context: Context, command: Command) -> None: + assert context.send_command + context.send_command(command) class Result: @hookimpl - async def on_end_run(self, exited_process: ExitedProcess[RunResult]) -> None: - self._run_result = exited_process.returned or RunResult(ret=None, exc=None) - - @hookimpl - def exception(self) -> Optional[BaseException]: - return self._run_result.exc + def exception(self, context: Context) -> Optional[BaseException]: + if not context.exited_process: + return None + if not context.exited_process.returned: + return None + return context.exited_process.returned.exc @hookimpl - def result(self) -> Any: - return self._run_result.result() + def result(self, context: Context) -> Any: + if not context.exited_process: + return None + if not context.exited_process.returned: + return None + return context.exited_process.returned.result() diff --git a/nextline/plugin/plugins/session/spawn.py b/nextline/plugin/plugins/session/spawn.py deleted file mode 100644 index 72283634..00000000 --- a/nextline/plugin/plugins/session/spawn.py 
+++ /dev/null @@ -1,63 +0,0 @@ -import multiprocessing as mp -from collections.abc import AsyncIterator, Callable -from concurrent.futures import ProcessPoolExecutor -from contextlib import asynccontextmanager -from functools import partial -from logging import getLogger -from typing import cast - -from tblib import pickling_support - -from nextline import spawned -from nextline.plugin.spec import Context -from nextline.spawned import Command, QueueIn, QueueOut, RunResult -from nextline.utils import MultiprocessingLogging, RunningProcess, run_in_process - -from .monitor import relay_queue - -pickling_support.install() - - -def _call_all(*funcs: Callable) -> None: - '''Execute callables and ignore return values. - - Used to call multiple initializers in ProcessPoolExecutor. - ''' - for func in funcs: - func() - - -@asynccontextmanager -async def run_session( - context: Context, -) -> AsyncIterator[tuple[RunningProcess[RunResult], Callable[[Command], None]]]: - assert context.run_arg - mp_context = mp.get_context('spawn') - queue_in = cast(QueueIn, mp_context.Queue()) - queue_out = cast(QueueOut, mp_context.Queue()) - send_command = SendCommand(queue_in) - async with MultiprocessingLogging(mp_context=mp_context) as mp_logging: - initializer = partial( - _call_all, - mp_logging.initializer, - partial(spawned.set_queues, queue_in, queue_out), - ) - executor_factory = partial( - ProcessPoolExecutor, - max_workers=1, - mp_context=mp_context, - initializer=initializer, - ) - func = partial(spawned.main, context.run_arg) - async with relay_queue(context, queue_out): - running = await run_in_process(func, executor_factory) - yield running, send_command - - -def SendCommand(queue_in: QueueIn) -> Callable[[Command], None]: - def _send_command(command: Command) -> None: - logger = getLogger(__name__) - logger.debug(f'send_pdb_command({command!r}') - queue_in.put(command) - - return _send_command diff --git a/nextline/plugin/spec.py b/nextline/plugin/spec.py index 205fd14f..4bad5c08 100644 --- a/nextline/plugin/spec.py +++ b/nextline/plugin/spec.py @@ -3,7 +3,7 @@ import apluggy -from nextline import spawned +from nextline import events, spawned from nextline.types import InitOptions, ResetOptions from nextline.utils import ExitedProcess, RunningProcess from nextline.utils.pubsub.broker import PubSub @@ -24,6 +24,9 @@ class Context: hook: apluggy.PluginManager pubsub: PubSub run_arg: spawned.RunArg | None = None + send_command: Callable[[spawned.Command], None] | None = None + running_process: RunningProcess[spawned.RunResult] | None = None + exited_process: ExitedProcess[spawned.RunResult] | None = None @hookspec @@ -75,11 +78,12 @@ async def run(context: Context): # type: ignore @hookspec -async def on_start_run( - context: Context, - running_process: RunningProcess[spawned.RunResult], - send_command: Callable[[spawned.Command], None], -) -> None: +async def on_event_in_process(context: Context, event: events.Event) -> None: + pass + + +@hookspec +async def on_start_run(context: Context) -> None: pass @@ -104,9 +108,7 @@ async def send_command(context: Context, command: spawned.Command) -> None: @hookspec -async def on_end_run( - context: Context, exited_process: ExitedProcess[spawned.RunResult] -) -> None: +async def on_end_run(context: Context) -> None: pass @@ -121,47 +123,47 @@ def result(context: Context) -> Any: @hookspec -async def on_start_trace(context: Context, event: spawned.OnStartTrace) -> None: +async def on_start_trace(context: Context, event: events.OnStartTrace) -> None: pass 
@hookspec -async def on_end_trace(context: Context, event: spawned.OnEndTrace) -> None: +async def on_end_trace(context: Context, event: events.OnEndTrace) -> None: pass @hookspec async def on_start_trace_call( - context: Context, event: spawned.OnStartTraceCall + context: Context, event: events.OnStartTraceCall ) -> None: pass @hookspec -async def on_end_trace_call(context: Context, event: spawned.OnEndTraceCall) -> None: +async def on_end_trace_call(context: Context, event: events.OnEndTraceCall) -> None: pass @hookspec -async def on_start_cmdloop(context: Context, event: spawned.OnStartCmdloop) -> None: +async def on_start_cmdloop(context: Context, event: events.OnStartCmdloop) -> None: pass @hookspec -async def on_end_cmdloop(context: Context, event: spawned.OnEndCmdloop) -> None: +async def on_end_cmdloop(context: Context, event: events.OnEndCmdloop) -> None: pass @hookspec -async def on_start_prompt(context: Context, event: spawned.OnStartPrompt) -> None: +async def on_start_prompt(context: Context, event: events.OnStartPrompt) -> None: pass @hookspec -async def on_end_prompt(context: Context, event: spawned.OnEndPrompt) -> None: +async def on_end_prompt(context: Context, event: events.OnEndPrompt) -> None: pass @hookspec -async def on_write_stdout(context: Context, event: spawned.OnWriteStdout) -> None: +async def on_write_stdout(context: Context, event: events.OnWriteStdout) -> None: pass diff --git a/nextline/spawned/__init__.py b/nextline/spawned/__init__.py index 0260e784..048a164a 100644 --- a/nextline/spawned/__init__.py +++ b/nextline/spawned/__init__.py @@ -26,18 +26,6 @@ import traceback from .commands import Command, PdbCommand -from .events import ( - Event, - OnEndCmdloop, - OnEndPrompt, - OnEndTrace, - OnEndTraceCall, - OnStartCmdloop, - OnStartPrompt, - OnStartTrace, - OnStartTraceCall, - OnWriteStdout, -) from .runner import run from .types import QueueIn, QueueOut, RunArg, RunResult, Statement diff --git a/nextline/spawned/path.py b/nextline/spawned/path.py new file mode 100644 index 00000000..574dc745 --- /dev/null +++ b/nextline/spawned/path.py @@ -0,0 +1,24 @@ +import os +from typing import Callable + + +def ToCanonicPath() -> Callable[[str], str]: + # Based on Bdb.canonic() + # https://github.com/python/cpython/blob/v3.10.5/Lib/bdb.py#L39-L54 + + cache = dict[str, str]() + + def _to_canonic_path(filename: str) -> str: + if filename == "<" + filename[1:-1] + ">": + return filename + canonic = cache.get(filename) + if not canonic: + canonic = os.path.abspath(filename) + canonic = os.path.normcase(canonic) + cache[filename] = canonic + return canonic + + return _to_canonic_path + + +to_canonic_path = ToCanonicPath() diff --git a/nextline/spawned/plugin/plugins/compose.py b/nextline/spawned/plugin/plugins/compose.py index 5f54d919..645c2534 100644 --- a/nextline/spawned/plugin/plugins/compose.py +++ b/nextline/spawned/plugin/plugins/compose.py @@ -3,9 +3,9 @@ from types import CodeType from typing import Any, Callable +from nextline.spawned.path import to_canonic_path from nextline.spawned.plugin.spec import hookimpl from nextline.spawned.types import RunArg, RunResult -from nextline.spawned.utils import to_canonic_path from . 
import _script diff --git a/nextline/spawned/plugin/plugins/local_.py b/nextline/spawned/plugin/plugins/local_.py index b5a5dc8c..f7786110 100644 --- a/nextline/spawned/plugin/plugins/local_.py +++ b/nextline/spawned/plugin/plugins/local_.py @@ -7,7 +7,7 @@ from exceptiongroup import catch from nextline.spawned.plugin.spec import hookimpl -from nextline.spawned.types import RunResult, TraceArgs, TraceFunction +from nextline.spawned.types import RunResult, TraceArgs, TraceCallInfo, TraceFunction from nextline.spawned.utils import WithContext from nextline.types import TraceNo @@ -65,7 +65,10 @@ def _keyboard_interrupt(exc: BaseException) -> None: nonlocal keyboard_interrupt_raised keyboard_interrupt_raised = True - with hook.with_.on_trace_call(trace_args=(frame, event, arg)): + trace_args = (frame, event, arg) + trace_call_info = TraceCallInfo(args=trace_args) + + with hook.with_.on_trace_call(trace_call_info=trace_call_info): with catch({KeyboardInterrupt: _keyboard_interrupt}): # TODO: Using exceptiongroup.catch() for Python 3.10. # Rewrite with except* for Python 3.11. @@ -83,16 +86,16 @@ def _keyboard_interrupt(exc: BaseException) -> None: class TraceCallHandler: - '''A plugin that keeps the trace call arguments during trace calls. + '''A plugin that keeps the trace call info during trace calls. - This plugin collect the trace call arguments when the context manager hook + This plugin collect the trace call info when the context manager hook `on_trace_call` is entered. It responds to the first result only hooks - `is_on_trace_call` and `current_trace_args`. + `is_on_trace_call`, `current_trace_args`, and `current_trace_call_info`. ''' def __init__(self) -> None: - self._trace_args_map = dict[TraceNo, TraceArgs]() self._traces_on_call = set[TraceNo]() + self._info_map = dict[TraceNo, TraceCallInfo]() @hookimpl def init(self, hook: PluginManager) -> None: @@ -100,15 +103,15 @@ def init(self, hook: PluginManager) -> None: @hookimpl @contextmanager - def on_trace_call(self, trace_args: TraceArgs) -> Iterator[None]: + def on_trace_call(self, trace_call_info: TraceCallInfo) -> Iterator[None]: trace_no = self._hook.hook.current_trace_no() self._traces_on_call.add(trace_no) - self._trace_args_map[trace_no] = trace_args + self._info_map[trace_no] = trace_call_info try: yield finally: self._traces_on_call.remove(trace_no) - del self._trace_args_map[trace_no] + del self._info_map[trace_no] @hookimpl def is_on_trace_call(self) -> Optional[bool]: @@ -118,4 +121,12 @@ def is_on_trace_call(self) -> Optional[bool]: @hookimpl def current_trace_args(self) -> Optional[TraceArgs]: trace_no = self._hook.hook.current_trace_no() - return self._trace_args_map.get(trace_no) + info = self._info_map.get(trace_no) + if info is None: + return None + return info.args + + @hookimpl + def current_trace_call_info(self) -> Optional[TraceCallInfo]: + trace_no = self._hook.hook.current_trace_no() + return self._info_map.get(trace_no) diff --git a/nextline/spawned/plugin/plugins/repeat.py b/nextline/spawned/plugin/plugins/repeat.py index e0b41133..134f4693 100644 --- a/nextline/spawned/plugin/plugins/repeat.py +++ b/nextline/spawned/plugin/plugins/repeat.py @@ -3,7 +3,7 @@ from apluggy import PluginManager, contextmanager -from nextline.spawned.events import ( +from nextline.events import ( OnEndCmdloop, OnEndPrompt, OnEndTrace, @@ -15,15 +15,15 @@ OnWriteStdout, ) from nextline.spawned.plugin.spec import hookimpl -from nextline.spawned.types import QueueOut, TraceArgs -from nextline.spawned.utils import to_canonic_path 
+from nextline.spawned.types import QueueOut, RunArg, TraceCallInfo from nextline.types import PromptNo, TraceNo class Repeater: @hookimpl - def init(self, hook: PluginManager, queue_out: QueueOut) -> None: + def init(self, hook: PluginManager, run_arg: RunArg, queue_out: QueueOut) -> None: self._hook = hook + self._run_no = run_arg.run_no self._queue_out = queue_out @hookimpl @@ -33,6 +33,7 @@ def on_start_trace(self, trace_no: TraceNo) -> None: thread_no = self._hook.hook.current_thread_no() task_no = self._hook.hook.current_task_no() event = OnStartTrace( + run_no=self._run_no, started_at=started_at, trace_no=trace_no, thread_no=thread_no, @@ -43,24 +44,22 @@ def on_start_trace(self, trace_no: TraceNo) -> None: @hookimpl def on_end_trace(self, trace_no: TraceNo) -> None: ended_at = datetime.datetime.utcnow() - event = OnEndTrace(ended_at=ended_at, trace_no=trace_no) + event = OnEndTrace(ended_at=ended_at, run_no=self._run_no, trace_no=trace_no) self._queue_out.put(event) @hookimpl @contextmanager - def on_trace_call(self, trace_args: TraceArgs) -> Iterator[None]: + def on_trace_call(self, trace_call_info: TraceCallInfo) -> Iterator[None]: started_at = datetime.datetime.utcnow() trace_no = self._hook.hook.current_trace_no() - frame, call_event, call_arg = trace_args - file_name = to_canonic_path(frame.f_code.co_filename) - line_no = frame.f_lineno event_start = OnStartTraceCall( started_at=started_at, + run_no=self._run_no, trace_no=trace_no, - file_name=file_name, - line_no=line_no, - frame_object_id=id(frame), - call_event=call_event, + file_name=trace_call_info.file_name, + line_no=trace_call_info.line_no, + frame_object_id=trace_call_info.frame_object_id, + event=trace_call_info.event, ) self._queue_out.put(event_start) @@ -68,7 +67,9 @@ def on_trace_call(self, trace_args: TraceArgs) -> Iterator[None]: yield finally: ended_at = datetime.datetime.utcnow() - event_end = OnEndTraceCall(ended_at=ended_at, trace_no=trace_no) + event_end = OnEndTraceCall( + ended_at=ended_at, run_no=self._run_no, trace_no=trace_no + ) self._queue_out.put(event_end) @hookimpl @@ -76,26 +77,36 @@ def on_trace_call(self, trace_args: TraceArgs) -> Iterator[None]: def on_cmdloop(self) -> Generator[None, None, None]: started_at = datetime.datetime.utcnow() trace_no = self._hook.hook.current_trace_no() - event_start = OnStartCmdloop(started_at=started_at, trace_no=trace_no) + event_start = OnStartCmdloop( + started_at=started_at, run_no=self._run_no, trace_no=trace_no + ) self._queue_out.put(event_start) try: yield finally: ended_at = datetime.datetime.utcnow() - event_end = OnEndCmdloop(ended_at=ended_at, trace_no=trace_no) + event_end = OnEndCmdloop( + ended_at=ended_at, run_no=self._run_no, trace_no=trace_no + ) self._queue_out.put(event_end) @hookimpl @contextmanager def on_prompt(self, prompt_no: PromptNo, text: str) -> Generator[None, str, None]: started_at = datetime.datetime.utcnow() - trace_no = self._hook.hook.current_trace_no() + trace_no: TraceNo = self._hook.hook.current_trace_no() + trace_call_info: TraceCallInfo = self._hook.hook.current_trace_call_info() event_start = OnStartPrompt( started_at=started_at, + run_no=self._run_no, trace_no=trace_no, prompt_no=prompt_no, prompt_text=text, + file_name=trace_call_info.file_name, + line_no=trace_call_info.line_no, + frame_object_id=trace_call_info.frame_object_id, + event=trace_call_info.event, ) self._queue_out.put(event_start) @@ -108,6 +119,7 @@ def on_prompt(self, prompt_no: PromptNo, text: str) -> Generator[None, str, None ended_at = 
datetime.datetime.utcnow() event_end = OnEndPrompt( ended_at=ended_at, + run_no=self._run_no, trace_no=trace_no, prompt_no=prompt_no, command=command, @@ -120,6 +132,7 @@ def on_write_stdout(self, trace_no: TraceNo, line: str) -> None: trace_no = self._hook.hook.current_trace_no() event = OnWriteStdout( written_at=written_at, + run_no=self._run_no, trace_no=trace_no, text=line, ) diff --git a/nextline/spawned/plugin/spec.py b/nextline/spawned/plugin/spec.py index 17cd589f..27f14a3f 100644 --- a/nextline/spawned/plugin/spec.py +++ b/nextline/spawned/plugin/spec.py @@ -1,5 +1,5 @@ from asyncio import Task -from collections.abc import Callable, Collection, Generator +from collections.abc import Callable, Collection, Generator, Iterator from threading import Thread from types import FrameType from typing import Any, Optional @@ -12,6 +12,7 @@ RunArg, RunResult, TraceArgs, + TraceCallInfo, TraceFunction, ) from nextline.types import PromptNo, TaskNo, ThreadNo, TraceNo @@ -123,8 +124,8 @@ def on_end_trace(trace_no: TraceNo) -> None: @hookspec @apluggy.contextmanager -def on_trace_call(trace_args: TraceArgs): # type: ignore - pass +def on_trace_call(trace_call_info: TraceCallInfo) -> Iterator[None]: + yield @hookspec(firstresult=True) @@ -137,6 +138,11 @@ def current_trace_args() -> Optional[TraceArgs]: pass +@hookspec(firstresult=True) +def current_trace_call_info() -> Optional[TraceCallInfo]: + pass + + @hookspec @apluggy.contextmanager def on_cmdloop(): # type: ignore diff --git a/nextline/spawned/runner.py b/nextline/spawned/runner.py index 3d4b2026..92bb5908 100644 --- a/nextline/spawned/runner.py +++ b/nextline/spawned/runner.py @@ -22,15 +22,15 @@ def _compile_and_run(hook: PluginManager, run_arg: RunArg) -> RunResult: func = hook.hook.compose_callable() except BaseException as exc: _remove_frame(exc=exc, frame=inspect.currentframe()) - return RunResult(ret=None, exc=exc) + return RunResult(exc=exc) trace_func = hook.hook.create_trace_func() try: with sys_trace(trace_func=trace_func, thread=run_arg.trace_threads): ret = func() - return RunResult(ret=ret, exc=None) + return RunResult(ret=ret) except BaseException as exc: _remove_frame(exc=exc, frame=inspect.currentframe()) - return RunResult(ret=None, exc=exc) + return RunResult(exc=exc) def _remove_frame(exc: BaseException, frame: Optional[FrameType]) -> None: diff --git a/nextline/spawned/types.py b/nextline/spawned/types.py index 26886b2b..f553293b 100644 --- a/nextline/spawned/types.py +++ b/nextline/spawned/types.py @@ -5,10 +5,11 @@ from types import FrameType from typing import Any, Callable, Optional +from nextline.events import Event +from nextline.spawned.path import to_canonic_path from nextline.types import RunNo, Statement from .commands import Command -from .events import Event # if TYPE_CHECKING: # from sys import TraceFunction as TraceFunc # type: ignore # noqa: F401 @@ -33,8 +34,8 @@ class RunArg: @dataclass class RunResult: - ret: Optional[Any] - exc: Optional[BaseException] + ret: Optional[Any] = None + exc: Optional[BaseException] = None _fmt_ret: Optional[str] = field(init=False, repr=False, default=None) _fmt_exc: Optional[str] = field(init=False, repr=False, default=None) @@ -64,3 +65,19 @@ def result(self) -> Any: # TODO: add a test for the exception raise self.exc return self.ret + + +@dataclass +class TraceCallInfo: + args: TraceArgs + file_name: str = field(init=False) + line_no: int = field(init=False) + frame_object_id: int = field(init=False) + event: str = field(init=False) + + def __post_init__(self) -> None: 
+ frame, event, _ = self.args + self.file_name = to_canonic_path(frame.f_code.co_filename) + self.line_no = frame.f_lineno + self.frame_object_id = id(frame) + self.event = event diff --git a/nextline/spawned/utils.py b/nextline/spawned/utils.py index 55fcd9ce..87e44f0e 100644 --- a/nextline/spawned/utils.py +++ b/nextline/spawned/utils.py @@ -1,4 +1,3 @@ -import os from types import FrameType from typing import Any, Callable, ContextManager, Optional @@ -31,25 +30,3 @@ def _global_trace( return _create_local_trace()(frame, event, arg) return _global_trace - - -def ToCanonicPath() -> Callable[[str], str]: - # Based on Bdb.canonic() - # https://github.com/python/cpython/blob/v3.10.5/Lib/bdb.py#L39-L54 - - cache = dict[str, str]() - - def _to_canonic_path(filename: str) -> str: - if filename == "<" + filename[1:-1] + ">": - return filename - canonic = cache.get(filename) - if not canonic: - canonic = os.path.abspath(filename) - canonic = os.path.normcase(canonic) - cache[filename] = canonic - return canonic - - return _to_canonic_path - - -to_canonic_path = ToCanonicPath() diff --git a/nextline/utils/__init__.py b/nextline/utils/__init__.py index 9834bd12..47bed802 100644 --- a/nextline/utils/__init__.py +++ b/nextline/utils/__init__.py @@ -13,7 +13,6 @@ 'profile_func', 'PubSub', 'PubSubItem', - 'ExecutorFactory', 'ExitedProcess', 'RunningProcess', 'run_in_process', @@ -28,6 +27,6 @@ from .peek import peek_stderr, peek_stdout, peek_textio from .profile import profile_func from .pubsub import PubSub, PubSubItem -from .run import ExecutorFactory, ExitedProcess, RunningProcess, run_in_process +from .run import ExitedProcess, RunningProcess, run_in_process from .thread_exception import ExcThread from .thread_task_id import ThreadTaskIdComposer diff --git a/nextline/utils/multiprocessing_logging.py b/nextline/utils/multiprocessing_logging.py index f3ab18fa..6ad80de8 100644 --- a/nextline/utils/multiprocessing_logging.py +++ b/nextline/utils/multiprocessing_logging.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import logging import multiprocessing as mp from functools import partial @@ -6,7 +7,7 @@ from logging.handlers import QueueHandler from multiprocessing.context import BaseContext from queue import Queue -from typing import Callable, Optional, cast +from typing import Optional, cast __all__ = ['MultiprocessingLogging'] @@ -17,7 +18,8 @@ def example_func() -> None: logger.warning('warning from another process') -class MultiprocessingLogging: +@contextlib.asynccontextmanager +async def MultiprocessingLogging(mp_context: Optional[BaseContext] = None): '''Collect logging from other processes in the main process. Example: @@ -31,10 +33,10 @@ class MultiprocessingLogging: ... from concurrent.futures import ProcessPoolExecutor ... ... # Start MultiprocessingLogging. - ... async with MultiprocessingLogging() as mp_logging: + ... async with MultiprocessingLogging() as initializer: ... ... # The initializer is given to ProcessPoolExecutor. - ... with ProcessPoolExecutor(initializer=mp_logging.initializer) as executor: + ... with ProcessPoolExecutor(initializer=initializer) as executor: ... ... # In another process, execute example_func(), which logs a warning. ... 
future = executor.submit(example_func) @@ -73,40 +75,24 @@ class MultiprocessingLogging: ''' - def __init__(self, mp_context: Optional[BaseContext] = None) -> None: - mp_context = mp_context or mp.get_context() - self._q = cast(Queue[LogRecord | None], mp_context.Queue()) - self._initializer = partial(_initializer, self._q) - self._task: asyncio.Task | None = None + mp_context = mp_context or mp.get_context() + queue = cast(Queue[LogRecord | None], mp_context.Queue()) + initializer = partial(_initializer, queue) - @property - def initializer(self) -> Callable[[], None]: - '''A callable with no args to be given to ProcessPoolExecutor as initializer.''' - return self._initializer - - async def open(self) -> None: - self._task = asyncio.create_task(self._listen()) - - async def _listen(self) -> None: - '''Receive loggings from other processes and handle them in the main process.''' - while (record := await asyncio.to_thread(self._q.get)) is not None: + async def _listen() -> None: + '''Receive loggings from other processes and handle them in the current process.''' + while (record := await asyncio.to_thread(queue.get)) is not None: logger = getLogger(record.name) if logger.getEffectiveLevel() <= record.levelno: logger.handle(record) - async def close(self) -> None: - if self._task: - await asyncio.to_thread(self._q.put, None) - await self._task - self._task = None - - async def __aenter__(self) -> 'MultiprocessingLogging': - await self.open() - return self + task = asyncio.create_task(_listen()) - async def __aexit__(self, exc_type, exc_value, traceback) -> None: # type: ignore - del exc_type, exc_value, traceback - await self.close() + try: + yield initializer + finally: + await asyncio.to_thread(queue.put, None) + await task def _initializer(queue: Queue[LogRecord]) -> None: diff --git a/nextline/utils/run.py b/nextline/utils/run.py index 9d3c6af9..636ca7ae 100644 --- a/nextline/utils/run.py +++ b/nextline/utils/run.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import os import signal from collections.abc import Callable, Generator @@ -9,7 +10,10 @@ from functools import partial from logging import getLogger from multiprocessing import Process -from typing import Generic, TypeVar +from multiprocessing.context import BaseContext +from typing import Any, Generic, TypeVar + +from .multiprocessing_logging import MultiprocessingLogging _T = TypeVar("_T") @@ -93,11 +97,22 @@ def __await__(self) -> Generator[None, None, ExitedProcess[_T]]: ) -ExecutorFactory = Callable[[], ProcessPoolExecutor] +def _call_all(*funcs: Callable[[], Any] | None) -> None: + '''Execute callables and ignore return values. + + Used to call multiple initializers in ProcessPoolExecutor. + ''' + for func in funcs: + if func is None: + continue + func() async def run_in_process( - func: Callable[[], _T], executor_factory: ExecutorFactory | None = None + func: Callable[[], _T], + mp_context: BaseContext | None = None, + initializer: Callable[[], None] | None = None, + collect_logging: bool = False, ) -> RunningProcess[_T]: '''Call a function in a separate process and return an awaitable. 
@@ -123,31 +138,38 @@ async def run_in_process( ''' - if executor_factory is None: - executor_factory = partial(ProcessPoolExecutor, max_workers=1) - process: Process | None = None event = asyncio.Event() async def _run() -> tuple[_T | None, BaseException | None]: nonlocal process - - with executor_factory() as executor: - loop = asyncio.get_running_loop() - future = loop.run_in_executor(executor, func) - process = list(executor._processes.values())[0] - - event.set() - ret = None - exc = None - try: - ret = await future - except BrokenProcessPool: - # NOTE: Not possible to use "as" for unknown reason. - pass - except BaseException as e: - exc = e - return ret, exc + nonlocal initializer + + async with contextlib.AsyncExitStack() as stack: + if collect_logging: + logging_initializer = await stack.enter_async_context( + MultiprocessingLogging(mp_context=mp_context) + ) + initializer = partial(_call_all, logging_initializer, initializer) + + with ProcessPoolExecutor( + max_workers=1, mp_context=mp_context, initializer=initializer + ) as executor: + loop = asyncio.get_running_loop() + future = loop.run_in_executor(executor, func) + process = list(executor._processes.values())[0] + + event.set() + ret = None + exc = None + try: + ret = await future + except BrokenProcessPool: + # NOTE: Not possible to use "as" for unknown reason. + pass + except BaseException as e: + exc = e + return ret, exc task = asyncio.create_task(_run()) await event.wait() diff --git a/tests/main/test_register.py b/tests/main/test_register.py index 22920c38..5c31404f 100644 --- a/tests/main/test_register.py +++ b/tests/main/test_register.py @@ -1,8 +1,8 @@ import time from nextline import Nextline +from nextline.events import OnStartPrompt from nextline.plugin.spec import Context, hookimpl -from nextline.spawned import OnStartPrompt def func(): diff --git a/tests/spawned/run/test_run.py b/tests/spawned/run/test_run.py index c68b6f7c..6d3c0119 100644 --- a/tests/spawned/run/test_run.py +++ b/tests/spawned/run/test_run.py @@ -4,15 +4,8 @@ import pytest -from nextline.spawned import ( - OnStartPrompt, - PdbCommand, - QueueIn, - QueueOut, - RunArg, - main, - set_queues, -) +from nextline.events import OnStartPrompt +from nextline.spawned import PdbCommand, QueueIn, QueueOut, RunArg, main, set_queues from nextline.types import RunNo diff --git a/tests/utils/run/test_run.py b/tests/utils/run/test_run.py index 6a6c3e5d..8dec0fa5 100644 --- a/tests/utils/run/test_run.py +++ b/tests/utils/run/test_run.py @@ -1,18 +1,15 @@ -from concurrent.futures import ProcessPoolExecutor -from functools import partial from typing import NoReturn -import pytest -from nextline.utils import ExecutorFactory, run_in_process +from nextline.utils import run_in_process def func_str() -> str: return 'foo' -async def test_success(executor_factory: ExecutorFactory) -> None: - running = await run_in_process(func_str, executor_factory) +async def test_success() -> None: + running = await run_in_process(func_str) assert running.process assert running.process_created_at result = await running @@ -28,8 +25,8 @@ async def test_default_executor() -> None: assert 'foo' == result.returned -async def test_repr(executor_factory: ExecutorFactory) -> None: - running = await run_in_process(func_str, executor_factory) +async def test_repr() -> None: + running = await run_in_process(func_str) repr(running) await running @@ -43,14 +40,9 @@ def func_raise() -> NoReturn: raise MockError() -async def test_error(executor_factory: ExecutorFactory) -> None: - running = await 
run_in_process(func_raise, executor_factory) +async def test_error() -> None: + running = await run_in_process(func_raise) result = await running assert running.process assert running.process.exitcode == 0 assert isinstance(result.raised, MockError) - - -@pytest.fixture -def executor_factory() -> ExecutorFactory: - return partial(ProcessPoolExecutor, max_workers=1) diff --git a/tests/utils/run/test_signal.py b/tests/utils/run/test_signal.py index 05a0071e..758cde04 100644 --- a/tests/utils/run/test_signal.py +++ b/tests/utils/run/test_signal.py @@ -2,7 +2,6 @@ import signal import time -from concurrent.futures import ProcessPoolExecutor from functools import partial from multiprocessing import Event from types import FrameType @@ -10,7 +9,7 @@ import pytest -from nextline.utils import ExecutorFactory, run_in_process +from nextline.utils import run_in_process if TYPE_CHECKING: from multiprocessing.synchronize import Event as _EventType @@ -35,8 +34,8 @@ def func_sleep() -> NoReturn: raise RuntimeError("to be terminated by here") -async def test_interrupt(executor_factory: ExecutorFactory, event: _EventType) -> None: - running = await run_in_process(func_sleep, executor_factory) +async def test_interrupt(event: _EventType) -> None: + running = await run_in_process(func_sleep, initializer=partial(initializer, event)) event.wait() running.interrupt() result = await running @@ -45,8 +44,8 @@ async def test_interrupt(executor_factory: ExecutorFactory, event: _EventType) - assert isinstance(result.raised, KeyboardInterrupt) -async def test_terminate(executor_factory: ExecutorFactory, event: _EventType) -> None: - running = await run_in_process(func_sleep, executor_factory) +async def test_terminate(event: _EventType) -> None: + running = await run_in_process(func_sleep, initializer=partial(initializer, event)) event.wait() running.terminate() result = await running @@ -55,8 +54,8 @@ async def test_terminate(executor_factory: ExecutorFactory, event: _EventType) - assert result -async def test_kill(executor_factory: ExecutorFactory, event: _EventType) -> None: - running = await run_in_process(func_sleep, executor_factory) +async def test_kill(event: _EventType) -> None: + running = await run_in_process(func_sleep, initializer=partial(initializer, event)) event.wait() running.kill() result = await running @@ -78,10 +77,10 @@ def func_catch_interrupt() -> str: return "bar" -async def test_interrupt_catch( - executor_factory: ExecutorFactory, event: _EventType -) -> None: - running = await run_in_process(func_catch_interrupt, executor_factory) +async def test_interrupt_catch(event: _EventType) -> None: + running = await run_in_process( + func_catch_interrupt, initializer=partial(initializer, event) + ) event.wait() running.interrupt() result = await running @@ -112,10 +111,10 @@ def func_handle_terminate() -> str: return "bar" -async def test_terminate_handle( - executor_factory: ExecutorFactory, event: _EventType -) -> None: - running = await run_in_process(func_handle_terminate, executor_factory) +async def test_terminate_handle(event: _EventType) -> None: + running = await run_in_process( + func_handle_terminate, initializer=partial(initializer, event) + ) event.wait() running.terminate() result = await running @@ -124,16 +123,6 @@ async def test_terminate_handle( assert 'foo' == result.returned -@pytest.fixture -def executor_factory(event: _EventType) -> ExecutorFactory: - return partial( - ProcessPoolExecutor, - max_workers=1, - initializer=initializer, - initargs=(event,), - ) - - 
@pytest.fixture def event() -> _EventType: return Event() diff --git a/tests/utils/test_multiprocessing_logging.py b/tests/utils/test_multiprocessing_logging.py index 7eeb2b9d..871cdd23 100644 --- a/tests/utils/test_multiprocessing_logging.py +++ b/tests/utils/test_multiprocessing_logging.py @@ -1,4 +1,3 @@ -import asyncio import logging import multiprocessing as mp from concurrent.futures import ProcessPoolExecutor @@ -9,18 +8,6 @@ from nextline.utils import MultiprocessingLogging -def test_init_sync(): - '''Assert the init without the running loop.''' - with pytest.raises(RuntimeError): - asyncio.get_running_loop() - assert MultiprocessingLogging() - - -async def test_close_without_open(): - mp_logging = MultiprocessingLogging() - await mp_logging.close() - - @pytest.mark.parametrize('mp_method', [None, 'spawn', 'fork', 'forkserver']) async def test_multiprocessing_logging( mp_method: str | None, caplog: LogCaptureFixture @@ -28,9 +15,9 @@ async def test_multiprocessing_logging( mp_context = mp.get_context(mp_method) if mp_method else None with caplog.at_level(logging.DEBUG): - async with MultiprocessingLogging(mp_context=mp_context) as mp_logging: + async with MultiprocessingLogging(mp_context=mp_context) as initializer: with ProcessPoolExecutor( - mp_context=mp_context, initializer=mp_logging.initializer + mp_context=mp_context, initializer=initializer ) as executor: fut = executor.submit(fn) assert "foo" == fut.result()
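
A minimal usage sketch of the reworked `run_in_process()` API that the hunks above introduce: `executor_factory` is replaced by `mp_context`, `initializer`, and `collect_logging`, and `MultiprocessingLogging` becomes an async context manager that yields an initializer. The sketch is assembled from the updated docstring and tests in this diff; `work()` and `set_flag()` are hypothetical helpers used only for illustration, not part of nextline.

import asyncio
import multiprocessing as mp

from nextline.utils import run_in_process


def set_flag() -> None:
    # Hypothetical initializer; runs once in the spawned worker process.
    pass


def work() -> str:
    # Hypothetical function executed in the spawned process.
    return 'foo'


async def main() -> None:
    mp_context = mp.get_context('spawn')
    running = await run_in_process(
        work,
        mp_context=mp_context,
        initializer=set_flag,   # replaces the old executor_factory-based initializer
        collect_logging=True,   # relay child-process log records via MultiprocessingLogging
    )
    exited = await running      # RunningProcess is awaitable and resolves to an ExitedProcess
    assert exited.returned == 'foo'


if __name__ == '__main__':
    asyncio.run(main())

This mirrors how the new `RunSession.run()` spawns a session in this diff: `run_in_process(func=partial(spawned.main, context.run_arg), mp_context=mp_context, initializer=partial(spawned.set_queues, queue_in, queue_out), collect_logging=True)`, with the returned `RunningProcess` stored on the context and awaited when the run ends.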