Skip to content

Commit

Permalink
Include RunArg to Context
Browse files Browse the repository at this point in the history
  • Loading branch information
TaiSakuma committed Jan 19, 2024
1 parent aee0a72 commit 96f0ae7
Show file tree
Hide file tree
Showing 11 changed files with 41 additions and 75 deletions.
7 changes: 3 additions & 4 deletions nextline/fsm/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,8 @@ async def on_exit_created(self, _: EventData) -> None:
await self._hook.ahook.start(context=self._context)

async def on_enter_initialized(self, _: EventData) -> None:
self._run_arg = self._hook.hook.compose_run_arg(context=self._context)
await self._hook.ahook.on_initialize_run(
context=self._context, run_arg=self._run_arg
)
self._context.run_arg = self._hook.hook.compose_run_arg(context=self._context)
await self._hook.ahook.on_initialize_run(context=self._context)

async def on_enter_running(self, _: EventData) -> None:
self.run_finished = asyncio.Event()
Expand All @@ -49,6 +47,7 @@ async def on_enter_running(self, _: EventData) -> None:
async def run() -> None:
async with self._hook.awith.run(context=self._context):
run_started.set()
self._context.run_arg = None
await self.finish() # type: ignore
self.run_finished.set()

Expand Down
22 changes: 8 additions & 14 deletions nextline/plugin/plugins/registrars/prompt_info.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import dataclasses
from logging import getLogger
from typing import Optional

from nextline.plugin.spec import Context, hookimpl
from nextline.spawned import (
Expand All @@ -11,14 +10,12 @@
OnStartPrompt,
OnStartTrace,
OnStartTraceCall,
RunArg,
)
from nextline.types import PromptInfo, PromptNo, RunNo, TraceNo
from nextline.types import PromptInfo, PromptNo, TraceNo


class PromptInfoRegistrar:
def __init__(self) -> None:
self._run_no: Optional[RunNo] = None
self._last_prompt_frame_map = dict[TraceNo, int]()
self._trace_call_map = dict[TraceNo, OnStartTraceCall]()
self._prompt_info_map = dict[PromptNo, PromptInfo]()
Expand All @@ -31,8 +28,7 @@ async def start(self) -> None:
pass

@hookimpl
async def on_initialize_run(self, run_arg: RunArg) -> None:
self._run_no = run_arg.run_no
async def on_initialize_run(self) -> None:
self._last_prompt_frame_map.clear()
self._trace_call_map.clear()
self._prompt_info_map.clear()
Expand All @@ -46,17 +42,15 @@ async def on_end_run(self, context: Context) -> None:
key = self._keys.pop()
await context.pubsub.end(key)

self._run_no = None

@hookimpl
async def on_start_trace(self, context: Context, event: OnStartTrace) -> None:
assert self._run_no is not None
assert context.run_arg
trace_no = event.trace_no

# TODO: Putting a prompt info for now because otherwise tests get stuck
# sometimes for an unknown reason. Need to investigate
prompt_info = PromptInfo(
run_no=self._run_no,
run_no=context.run_arg.run_no,
trace_no=trace_no,
prompt_no=PromptNo(-1),
open=False,
Expand All @@ -81,7 +75,7 @@ async def on_start_trace_call(self, event: OnStartTraceCall) -> None:

@hookimpl
async def on_end_trace_call(self, context: Context, event: OnEndTraceCall) -> None:
assert self._run_no is not None
assert context.run_arg
trace_no = event.trace_no
trace_call = self._trace_call_map.pop(event.trace_no, None)
if trace_call is None:
Expand All @@ -97,7 +91,7 @@ async def on_end_trace_call(self, context: Context, event: OnEndTraceCall) -> No
# prompt info.

prompt_info = PromptInfo(
run_no=self._run_no,
run_no=context.run_arg.run_no,
trace_no=trace_no,
prompt_no=PromptNo(-1),
open=False,
Expand All @@ -115,12 +109,12 @@ async def on_end_trace_call(self, context: Context, event: OnEndTraceCall) -> No

@hookimpl
async def on_start_prompt(self, context: Context, event: OnStartPrompt) -> None:
assert self._run_no is not None
assert context.run_arg
trace_no = event.trace_no
prompt_no = event.prompt_no
trace_call = self._trace_call_map[trace_no]
prompt_info = PromptInfo(
run_no=self._run_no,
run_no=context.run_arg.run_no,
trace_no=trace_no,
prompt_no=prompt_no,
open=True,
Expand Down
11 changes: 4 additions & 7 deletions nextline/plugin/plugins/registrars/prompt_notice.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional

from nextline.plugin.spec import Context, hookimpl
from nextline.spawned import OnEndTraceCall, OnStartPrompt, OnStartTraceCall, RunArg
from nextline.spawned import OnEndTraceCall, OnStartPrompt, OnStartTraceCall
from nextline.types import PromptNotice, RunNo, TraceNo


Expand All @@ -13,33 +13,30 @@ def __init__(self) -> None:
self._logger = getLogger(__name__)

@hookimpl
async def on_initialize_run(self, run_arg: RunArg) -> None:
self._run_no = run_arg.run_no
async def on_initialize_run(self) -> None:
self._trace_call_map.clear()

@hookimpl
async def on_end_run(self, context: Context) -> None:
await context.pubsub.end('prompt_notice')
self._run_no = None

@hookimpl
async def on_start_trace_call(self, event: OnStartTraceCall) -> None:
self._trace_call_map[event.trace_no] = event

@hookimpl
async def on_end_trace_call(self, event: OnEndTraceCall) -> None:
assert self._run_no is not None
self._trace_call_map.pop(event.trace_no, None)

@hookimpl
async def on_start_prompt(self, context: Context, event: OnStartPrompt) -> None:
assert self._run_no is not None
assert context.run_arg
trace_no = event.trace_no
prompt_no = event.prompt_no
trace_call = self._trace_call_map[trace_no]
prompt_notice = PromptNotice(
started_at=event.started_at,
run_no=self._run_no,
run_no=context.run_arg.run_no,
trace_no=trace_no,
prompt_no=prompt_no,
prompt_text=event.prompt_text,
Expand Down
11 changes: 4 additions & 7 deletions nextline/plugin/plugins/registrars/run_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@

from nextline import spawned
from nextline.plugin.spec import Context, hookimpl
from nextline.spawned import RunArg
from nextline.types import RunInfo, RunNo
from nextline.types import RunInfo
from nextline.utils import ExitedProcess


class RunInfoRegistrar:
def __init__(self) -> None:
self._run_no: Optional[RunNo] = None
self._script: Optional[str] = None
self._run_info: Optional[RunInfo] = None

Expand All @@ -20,10 +18,10 @@ async def on_change_script(self, script: str) -> None:
self._script = script

@hookimpl
async def on_initialize_run(self, context: Context, run_arg: RunArg) -> None:
self._run_no = run_arg.run_no
async def on_initialize_run(self, context: Context) -> None:
assert context.run_arg
self._run_info = RunInfo(
run_no=run_arg.run_no, state='initialized', script=self._script
run_no=context.run_arg.run_no, state='initialized', script=self._script
)
await context.pubsub.publish('run_info', self._run_info)

Expand Down Expand Up @@ -54,4 +52,3 @@ async def on_end_run(
await context.pubsub.publish('run_info', self._run_info)

self._run_info = None
self._run_no = None
6 changes: 3 additions & 3 deletions nextline/plugin/plugins/registrars/run_no.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from nextline.plugin.spec import Context, hookimpl
from nextline.spawned import RunArg


class RunNoRegistrar:
@hookimpl
async def on_initialize_run(self, context: Context, run_arg: RunArg) -> None:
await context.pubsub.publish('run_no', run_arg.run_no)
async def on_initialize_run(self, context: Context) -> None:
assert context.run_arg
await context.pubsub.publish('run_no', context.run_arg.run_no)
15 changes: 3 additions & 12 deletions nextline/plugin/plugins/registrars/stdout.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,19 @@
from typing import Optional

from nextline.plugin.spec import Context, hookimpl
from nextline.spawned import OnWriteStdout, RunArg
from nextline.spawned import OnWriteStdout
from nextline.types import RunNo, StdoutInfo


class StdoutRegistrar:
def __init__(self) -> None:
self._run_no: Optional[RunNo] = None

@hookimpl
async def on_initialize_run(self, run_arg: RunArg) -> None:
self._run_no = run_arg.run_no
self._trace_nos = ()

@hookimpl
async def on_end_run(self) -> None:
self._run_no = None

@hookimpl
async def on_write_stdout(self, context: Context, event: OnWriteStdout) -> None:
assert self._run_no is not None
assert context.run_arg
stdout_info = StdoutInfo(
run_no=self._run_no,
run_no=context.run_arg.run_no,
trace_no=event.trace_no,
text=event.text,
written_at=event.written_at,
Expand Down
14 changes: 5 additions & 9 deletions nextline/plugin/plugins/registrars/trace_info.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
import dataclasses
import datetime
from typing import Optional

from nextline.plugin.spec import Context, hookimpl
from nextline.spawned import OnEndTrace, OnStartTrace, RunArg
from nextline.types import RunNo, TraceInfo, TraceNo
from nextline.spawned import OnEndTrace, OnStartTrace
from nextline.types import TraceInfo, TraceNo


class TraceInfoRegistrar:
def __init__(self) -> None:
self._run_no: Optional[RunNo] = None
self._trace_info_map = dict[TraceNo, TraceInfo]()

@hookimpl
async def on_initialize_run(self, run_arg: RunArg) -> None:
self._run_no = run_arg.run_no
async def on_initialize_run(self) -> None:
self._trace_info_map = {}

@hookimpl
Expand All @@ -28,13 +25,12 @@ async def on_end_run(self, context: Context) -> None:
ended_at=datetime.datetime.utcnow(),
)
await context.pubsub.publish('trace_info', trace_info_end)
self._run_no = None

@hookimpl
async def on_start_trace(self, context: Context, event: OnStartTrace) -> None:
assert self._run_no is not None
assert context.run_arg
trace_info = TraceInfo(
run_no=self._run_no,
run_no=context.run_arg.run_no,
trace_no=event.trace_no,
thread_no=event.thread_no,
task_no=event.task_no,
Expand Down
12 changes: 3 additions & 9 deletions nextline/plugin/plugins/registrars/trace_nos.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@
from typing import Optional

from nextline.plugin.spec import Context, hookimpl
from nextline.spawned import OnEndTrace, OnStartTrace, RunArg
from nextline.types import RunNo, TraceNo
from nextline.spawned import OnEndTrace, OnStartTrace
from nextline.types import TraceNo


class TraceNumbersRegistrar:
def __init__(self) -> None:
self._run_no: Optional[RunNo] = None
self._trace_nos: tuple[TraceNo, ...] = ()

@hookimpl
async def on_initialize_run(self, run_arg: RunArg) -> None:
self._run_no = run_arg.run_no
async def on_initialize_run(self) -> None:
self._trace_nos = ()

@hookimpl
async def on_end_run(self, context: Context) -> None:
self._run_no = None

self._trace_nos = ()
await context.pubsub.publish('trace_nos', self._trace_nos)

Expand Down
8 changes: 2 additions & 6 deletions nextline/plugin/plugins/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tblib import pickling_support

from nextline.plugin.spec import Context, hookimpl
from nextline.spawned import Command, RunArg, RunResult
from nextline.spawned import Command, RunResult
from nextline.utils import ExitedProcess, RunningProcess

from .spawn import run_session
Expand All @@ -15,15 +15,11 @@


class RunSession:
@hookimpl
async def on_initialize_run(self, run_arg: RunArg) -> None:
self._run_arg = run_arg

@hookimpl
@apluggy.asynccontextmanager
async def run(self, context: Context) -> AsyncIterator[None]:
ahook = context.hook.ahook
async with run_session(context, self._run_arg) as (running, send_command):
async with run_session(context) as (running, send_command):
await ahook.on_start_run(
context=context, running_process=running, send_command=send_command
)
Expand Down
7 changes: 4 additions & 3 deletions nextline/plugin/plugins/session/spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from nextline import spawned
from nextline.plugin.spec import Context
from nextline.spawned import Command, QueueIn, QueueOut, RunArg, RunResult
from nextline.spawned import Command, QueueIn, QueueOut, RunResult
from nextline.utils import MultiprocessingLogging, RunningProcess, run_in_process

from .monitor import relay_queue
Expand All @@ -29,8 +29,9 @@ def _call_all(*funcs: Callable) -> None:

@asynccontextmanager
async def run_session(
context: Context, run_arg: RunArg
context: Context,
) -> AsyncIterator[tuple[RunningProcess[RunResult], Callable[[Command], None]]]:
assert context.run_arg
mp_context = mp.get_context('spawn')
queue_in = cast(QueueIn, mp_context.Queue())
queue_out = cast(QueueOut, mp_context.Queue())
Expand All @@ -47,7 +48,7 @@ async def run_session(
mp_context=mp_context,
initializer=initializer,
)
func = partial(spawned.main, run_arg)
func = partial(spawned.main, context.run_arg)
async with relay_queue(context, queue_out):
running = await run_in_process(func, executor_factory)
yield running, send_command
Expand Down
3 changes: 2 additions & 1 deletion nextline/plugin/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class Context:
nextline: 'Nextline'
hook: apluggy.PluginManager
pubsub: PubSub
run_arg: spawned.RunArg | None = None


@hookspec
Expand Down Expand Up @@ -63,7 +64,7 @@ def compose_run_arg(context: Context) -> Optional[spawned.RunArg]:


@hookspec
async def on_initialize_run(context: Context, run_arg: spawned.RunArg) -> None:
async def on_initialize_run(context: Context) -> None:
pass


Expand Down

0 comments on commit 96f0ae7

Please sign in to comment.