Skip to content

Commit

Permalink
applied code review from Alex
Browse files Browse the repository at this point in the history
  • Loading branch information
khsrali committed Jan 9, 2025
1 parent d6176fd commit 1e78d7f
Show file tree
Hide file tree
Showing 19 changed files with 234 additions and 258 deletions.
5 changes: 2 additions & 3 deletions src/aiida/calculations/monitors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@

import tempfile
from pathlib import Path
from typing import Union

from aiida.orm import CalcJobNode
from aiida.transports import AsyncTransport, Transport
from aiida.transports import Transport


def always_kill(node: CalcJobNode, transport: Union['Transport', 'AsyncTransport']) -> str | None:
def always_kill(node: CalcJobNode, transport: Transport) -> str | None:
"""Retrieve and inspect files in working directory of job to determine whether the job should be killed.
This particular implementation is just for demonstration purposes and will kill the job as long as there is a
Expand Down
14 changes: 7 additions & 7 deletions src/aiida/engine/daemon/execmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from aiida.schedulers.datastructures import JobState

if TYPE_CHECKING:
from aiida.transports import AsyncTransport, Transport
from aiida.transports import Transport

REMOTE_WORK_DIRECTORY_LOST_FOUND = 'lost+found'

Expand Down Expand Up @@ -64,7 +64,7 @@ def _find_data_node(inputs: MappingType[str, Any], uuid: str) -> Optional[Node]:

async def upload_calculation(
node: CalcJobNode,
transport: Union['Transport', 'AsyncTransport'],
transport: Transport,
calc_info: CalcInfo,
folder: Folder,
inputs: Optional[MappingType[str, Any]] = None,
Expand Down Expand Up @@ -393,7 +393,7 @@ async def _copy_sandbox_files(logger, node, transport, folder, workdir: Path):
await transport.put_async(folder.get_abs_path(filename), workdir.joinpath(filename))


def submit_calculation(calculation: CalcJobNode, transport: Union['Transport', 'AsyncTransport']) -> str | ExitCode:
def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str | ExitCode:
"""Submit a previously uploaded `CalcJob` to the scheduler.
:param calculation: the instance of CalcJobNode to submit.
Expand Down Expand Up @@ -423,7 +423,7 @@ def submit_calculation(calculation: CalcJobNode, transport: Union['Transport', '
return result


async def stash_calculation(calculation: CalcJobNode, transport: Union['Transport', 'AsyncTransport']) -> None:
async def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None:
"""Stash files from the working directory of a completed calculation to a permanent remote folder.
After a calculation has been completed, optionally stash files from the work directory to a storage location on the
Expand Down Expand Up @@ -489,7 +489,7 @@ async def stash_calculation(calculation: CalcJobNode, transport: Union['Transpor


async def retrieve_calculation(
calculation: CalcJobNode, transport: Union['Transport', 'AsyncTransport'], retrieved_temporary_folder: str
calculation: CalcJobNode, transport: Transport, retrieved_temporary_folder: str
) -> FolderData | None:
"""Retrieve all the files of a completed job calculation using the given transport.
Expand Down Expand Up @@ -554,7 +554,7 @@ async def retrieve_calculation(
return retrieved_files


def kill_calculation(calculation: CalcJobNode, transport: Union['Transport', 'AsyncTransport']) -> None:
def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None:
"""Kill the calculation through the scheduler
:param calculation: the instance of CalcJobNode to kill.
Expand Down Expand Up @@ -589,7 +589,7 @@ def kill_calculation(calculation: CalcJobNode, transport: Union['Transport', 'As

async def retrieve_files_from_list(
calculation: CalcJobNode,
transport: Union['Transport', 'AsyncTransport'],
transport: Transport,
folder: str,
retrieve_list: List[Union[str, Tuple[str, str, int], list]],
) -> None:
Expand Down
9 changes: 3 additions & 6 deletions src/aiida/engine/processes/calcjobs/monitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
import inspect
import typing as t
from datetime import datetime, timedelta
from typing import Union

from aiida.common.lang import type_check
from aiida.common.log import AIIDA_LOGGER
from aiida.orm import CalcJobNode, Dict
from aiida.plugins import BaseFactory

if t.TYPE_CHECKING:
from aiida.transports import AsyncTransport, Transport
from aiida.transports import Transport

LOGGER = AIIDA_LOGGER.getChild(__name__)

Expand Down Expand Up @@ -123,9 +122,7 @@ def validate(self):
parameters = list(signature.parameters.keys())

if any(required_parameter not in parameters for required_parameter in ('node', 'transport')):
correct_signature = (
"(node: CalcJobNode, transport: Union['Transport', 'AsyncTransport'], **kwargs) str | None:"
)
correct_signature = '(node: CalcJobNode, transport: Transport, **kwargs) str | None:'
raise ValueError(
f'The monitor `{self.entry_point}` has an invalid function signature, it should be: {correct_signature}'
)
Expand Down Expand Up @@ -179,7 +176,7 @@ def monitors(self) -> collections.OrderedDict:
def process(
self,
node: CalcJobNode,
transport: Union['Transport', 'AsyncTransport'],
transport: Transport,
) -> CalcJobMonitorResult | None:
"""Call all monitors in order and return the result as one returns anything other than ``None``.
Expand Down
6 changes: 3 additions & 3 deletions src/aiida/engine/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
import contextvars
import logging
import traceback
from typing import TYPE_CHECKING, Awaitable, Dict, Hashable, Iterator, Optional, Union
from typing import TYPE_CHECKING, Awaitable, Dict, Hashable, Iterator, Optional

from aiida.orm import AuthInfo

if TYPE_CHECKING:
from aiida.transports import AsyncTransport, Transport
from aiida.transports import Transport

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,7 +54,7 @@ def loop(self) -> asyncio.AbstractEventLoop:
return self._loop

@contextlib.contextmanager
def request_transport(self, authinfo: AuthInfo) -> Iterator[Awaitable[Union['Transport', 'AsyncTransport']]]:
def request_transport(self, authinfo: AuthInfo) -> Iterator[Awaitable['Transport']]:
"""Request a transport from an authinfo. Because the client is not allowed to
request a transport immediately they will instead be given back a future
that can be awaited to get the transport::
Expand Down
6 changes: 3 additions & 3 deletions src/aiida/orm/authinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
###########################################################################
"""Module for the `AuthInfo` ORM class."""

from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Type

from aiida.common import exceptions
from aiida.manage import get_manager
Expand All @@ -21,7 +21,7 @@
from aiida.orm import Computer, User
from aiida.orm.implementation import StorageBackend
from aiida.orm.implementation.authinfos import BackendAuthInfo # noqa: F401
from aiida.transports import AsyncTransport, Transport
from aiida.transports import Transport

__all__ = ('AuthInfo',)

Expand Down Expand Up @@ -166,7 +166,7 @@ def get_workdir(self) -> str:
except KeyError:
return self.computer.get_workdir()

def get_transport(self) -> Union['Transport', 'AsyncTransport']:
def get_transport(self) -> 'Transport':
"""Return a fully configured transport that can be used to connect to the computer set for this instance."""
computer = self.computer
transport_type = computer.transport_type
Expand Down
2 changes: 1 addition & 1 deletion src/aiida/orm/computers.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ def is_user_enabled(self, user: 'User') -> bool:
# Return False if the user is not configured (in a sense, it is disabled for that user)
return False

def get_transport(self, user: Optional['User'] = None) -> Union['Transport', 'AsyncTransport']:
def get_transport(self, user: Optional['User'] = None) -> 'Transport':
"""Return a Transport class, configured with all correct parameters.
The Transport is closed (meaning that if you want to run any operation with
it, you have to open it first (i.e., e.g. for a SSH transport, you have
Expand Down
6 changes: 3 additions & 3 deletions src/aiida/orm/nodes/process/calculation/calcjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from aiida.parsers import Parser
from aiida.schedulers.datastructures import JobInfo, JobState
from aiida.tools.calculations import CalculationTools
from aiida.transports import AsyncTransport, Transport
from aiida.transports import Transport

__all__ = ('CalcJobNode',)

Expand Down Expand Up @@ -450,10 +450,10 @@ def get_authinfo(self) -> 'AuthInfo':

return computer.get_authinfo(self.user)

def get_transport(self) -> Union['Transport', 'AsyncTransport']:
def get_transport(self) -> 'Transport':
"""Return the transport for this calculation.
:return: Union['Transport', 'AsyncTransport'] configured
:return: Transport configured
with the `AuthInfo` associated to the computer of this node
"""
return self.get_authinfo().get_transport()
Expand Down
7 changes: 3 additions & 4 deletions src/aiida/orm/utils/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import os
import typing as t
from typing import Union

from aiida.orm.nodes.data.remote.base import RemoteData

Expand All @@ -21,14 +20,14 @@

from aiida import orm
from aiida.orm.implementation import StorageBackend
from aiida.transports import AsyncTransport, Transport
from aiida.transports import Transport


def clean_remote(transport: Union['Transport', 'AsyncTransport'], path: str) -> None:
def clean_remote(transport: Transport, path: str) -> None:
"""Recursively remove a remote folder, with the given absolute path, and all its contents. The path should be
made accessible through the transport channel, which should already be open
:param transport: an open Union['Transport', 'AsyncTransport'] channel
:param transport: an open Transport channel
:param path: an absolute path on the remote made available through the transport
"""
if not isinstance(path, str):
Expand Down
2 changes: 1 addition & 1 deletion src/aiida/plugins/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def TransportFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint:
def TransportFactory(
entry_point_name: str, load: bool = True
) -> Union[EntryPoint, Type['Transport'], Type['AsyncTransport']]:
"""Return the Union['Transport', 'AsyncTransport'] sub class registered under the given entry point.
"""Return the Transport sub class registered under the given entry point.
:param entry_point_name: the entry point name.
:param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself.
Expand Down
5 changes: 2 additions & 3 deletions src/aiida/schedulers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import abc
import typing as t
from typing import Union

from aiida.common import exceptions, log, warnings
from aiida.common.datastructures import CodeRunMode
Expand All @@ -22,7 +21,7 @@
from aiida.schedulers.datastructures import JobInfo, JobResource, JobTemplate, JobTemplateCodeInfo

if t.TYPE_CHECKING:
from aiida.transports import AsyncTransport, Transport
from aiida.transports import Transport

__all__ = ('Scheduler', 'SchedulerError', 'SchedulerParsingError')

Expand Down Expand Up @@ -366,7 +365,7 @@ def transport(self):

return self._transport

def set_transport(self, transport: Union['Transport', 'AsyncTransport']):
def set_transport(self, transport: Transport):
"""Set the transport to be used to query the machine or to submit scripts.
This class assumes that the transport is open and active.
Expand Down
1 change: 1 addition & 0 deletions src/aiida/transports/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

__all__ = (
'AsyncTransport',
'BlockingTransport',
'SshTransport',
'Transport',
'TransportPath',
Expand Down
4 changes: 2 additions & 2 deletions src/aiida/transports/plugins/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

from aiida.common.warnings import warn_deprecation
from aiida.transports import cli as transport_cli
from aiida.transports.transport import Transport, TransportInternalError, TransportPath, path_to_str
from aiida.transports.transport import BlockingTransport, TransportInternalError, TransportPath, path_to_str


# refactor or raise the limit: issue #1784
class LocalTransport(Transport):
class LocalTransport(BlockingTransport):
"""Support copy and command execution on the same host on which AiiDA is running via direct file copy and
execution commands.
Expand Down
4 changes: 2 additions & 2 deletions src/aiida/transports/plugins/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from aiida.common.escaping import escape_for_bash
from aiida.common.warnings import warn_deprecation

from ..transport import Transport, TransportInternalError, TransportPath, path_to_str
from ..transport import BlockingTransport, TransportInternalError, TransportPath, path_to_str

__all__ = ('SshTransport', 'convert_to_bool', 'parse_sshconfig')

Expand Down Expand Up @@ -62,7 +62,7 @@ def convert_to_bool(string):
raise ValueError('Invalid boolean value provided')


class SshTransport(Transport):
class SshTransport(BlockingTransport):
"""Support connection, command execution and data transfer to remote computers via SSH+SFTP."""

# Valid keywords accepted by the connect method of paramiko.SSHClient
Expand Down
7 changes: 5 additions & 2 deletions src/aiida/transports/plugins/ssh_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import glob
import os
from pathlib import Path, PurePath
from typing import Optional, Union
from typing import Optional

import asyncssh
import click
Expand Down Expand Up @@ -176,6 +176,9 @@ async def close_async(self):
await self._conn.wait_closed()
self._is_open = False

def chown(self, path, uid, gid):
raise NotImplementedError

def __str__(self):
return f"{'OPEN' if self._is_open else 'CLOSED'} [AsyncSshTransport]"

Expand Down Expand Up @@ -1199,7 +1202,7 @@ async def chown_async(self, path: TransportPath, uid: int, gid: int):

async def copy_from_remote_to_remote_async(
self,
transportdestination: Union['Transport', 'AsyncTransport'],
transportdestination: Transport,
remotesource: TransportPath,
remotedestination: TransportPath,
**kwargs,
Expand Down
Loading

0 comments on commit 1e78d7f

Please sign in to comment.