diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index a091e0c3f96f1..d2921ccf67e8a 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -303,19 +303,23 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None: self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) self.running.add(key) - def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None: + def change_state( + self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True + ) -> None: """ Change state of the task. - :param info: Executor information for the task instance :param key: Unique key for the task instance :param state: State to set for the task. + :param info: Executor information for the task instance + :param remove_running: Whether or not to remove the TI key from running set """ self.log.debug("Changing state: %s", key) - try: - self.running.remove(key) - except KeyError: - self.log.debug("Could not find key: %s", key) + if remove_running: + try: + self.running.remove(key) + except KeyError: + self.log.debug("Could not find key: %s", key) self.event_buffer[key] = state, info def fail(self, key: TaskInstanceKey, info=None) -> None: @@ -345,6 +349,15 @@ def queued(self, key: TaskInstanceKey, info=None) -> None: """ self.change_state(key, TaskInstanceState.QUEUED, info) + def running_state(self, key: TaskInstanceKey, info=None) -> None: + """ + Set running state for the event. + + :param info: Executor information for the task instance + :param key: Unique key for the task instance + """ + self.change_state(key, TaskInstanceState.RUNNING, info, remove_running=False) + def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]: """ Return and flush the event buffer. diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py index a315ee31f96b1..80fb673cab844 100644 --- a/airflow/executors/debug_executor.py +++ b/airflow/executors/debug_executor.py @@ -155,8 +155,3 @@ def end(self) -> None: def terminate(self) -> None: self._terminated.set() - - def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None: - self.log.debug("Popping %s from executor task queue.", key) - self.running.remove(key) - self.event_buffer[key] = state, info diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 631de5692e468..49a065b5f56e3 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -692,7 +692,12 @@ def _process_executor_events(self, session: Session) -> int: ti_primary_key_to_try_number_map[ti_key.primary] = ti_key.try_number self.log.info("Received executor event with state %s for task instance %s", state, ti_key) - if state in (TaskInstanceState.FAILED, TaskInstanceState.SUCCESS, TaskInstanceState.QUEUED): + if state in ( + TaskInstanceState.FAILED, + TaskInstanceState.SUCCESS, + TaskInstanceState.QUEUED, + TaskInstanceState.RUNNING, + ): tis_with_right_state.append(ti_key) # Return if no finished tasks @@ -711,7 +716,7 @@ def _process_executor_events(self, session: Session) -> int: buffer_key = ti.key.with_try_number(try_number) state, info = event_buffer.pop(buffer_key) - if state == TaskInstanceState.QUEUED: + if state in (TaskInstanceState.QUEUED, TaskInstanceState.RUNNING): ti.external_executor_id = info self.log.info("Setting external_id for %s to %s", ti, info) continue diff --git a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py index 6730e5716822e..c5e7e3d6b46c3 100644 --- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py @@ -400,7 +400,12 @@ def attempt_task_runs(self): else: task = run_task_response["tasks"][0] self.active_workers.add_task(task, task_key, queue, cmd, exec_config, attempt_number) - self.queued(task_key, task.task_arn) + try: + self.running_state(task_key, task.task_arn) + except AttributeError: + # running_state is newly added, and only needed to support task adoption (an optional + # executor feature). + pass if failure_reasons: self.log.error( "Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.", diff --git a/airflow/providers/celery/executors/celery_executor.py b/airflow/providers/celery/executors/celery_executor.py index 0b4293cde728f..1d4342f2940d8 100644 --- a/airflow/providers/celery/executors/celery_executor.py +++ b/airflow/providers/celery/executors/celery_executor.py @@ -368,8 +368,14 @@ def update_all_task_states(self) -> None: if state: self.update_task_state(key, state, info) - def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None: - super().change_state(key, state, info) + def change_state( + self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True + ) -> None: + try: + super().change_state(key, state, info, remove_running=remove_running) + except AttributeError: + # Earlier versions of the BaseExecutor don't accept the remove_running parameter for this method + super().change_state(key, state, info) self.tasks.pop(key, None) def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None: diff --git a/tests/executors/test_base_executor.py b/tests/executors/test_base_executor.py index 432ed867acdd9..0e75751fafee3 100644 --- a/tests/executors/test_base_executor.py +++ b/tests/executors/test_base_executor.py @@ -33,7 +33,7 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.utils import timezone -from airflow.utils.state import State +from airflow.utils.state import State, TaskInstanceState def test_supports_sentry(): @@ -363,3 +363,54 @@ def test_running_retry_attempt_type(loop_duration, total_tries): assert a.elapsed > min_seconds_for_test assert a.total_tries == total_tries assert a.tries_after_min == 1 + + +def test_state_fail(): + executor = BaseExecutor() + key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1) + executor.running.add(key) + info = "info" + executor.fail(key, info=info) + assert not executor.running + assert executor.event_buffer[key] == (TaskInstanceState.FAILED, info) + + +def test_state_success(): + executor = BaseExecutor() + key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1) + executor.running.add(key) + info = "info" + executor.success(key, info=info) + assert not executor.running + assert executor.event_buffer[key] == (TaskInstanceState.SUCCESS, info) + + +def test_state_queued(): + executor = BaseExecutor() + key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1) + executor.running.add(key) + info = "info" + executor.queued(key, info=info) + assert not executor.running + assert executor.event_buffer[key] == (TaskInstanceState.QUEUED, info) + + +def test_state_generic(): + executor = BaseExecutor() + key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1) + executor.running.add(key) + info = "info" + executor.queued(key, info=info) + assert not executor.running + assert executor.event_buffer[key] == (TaskInstanceState.QUEUED, info) + + +def test_state_running(): + executor = BaseExecutor() + key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1) + executor.running.add(key) + info = "info" + executor.running_state(key, info=info) + # Running state should not remove a command as running + assert executor.running + assert executor.event_buffer[key] == (TaskInstanceState.RUNNING, info) diff --git a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py index fd7bf6772620a..524360dbac2ad 100644 --- a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py +++ b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py @@ -367,7 +367,8 @@ def test_stopped_tasks(self): class TestAwsEcsExecutor: """Tests the AWS ECS Executor.""" - def test_execute(self, mock_airflow_key, mock_executor): + @mock.patch("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor.change_state") + def test_execute(self, change_state_mock, mock_airflow_key, mock_executor): """Test execution from end-to-end.""" airflow_key = mock_airflow_key() @@ -393,6 +394,9 @@ def test_execute(self, mock_airflow_key, mock_executor): # Task is stored in active worker. assert 1 == len(mock_executor.active_workers) assert ARN1 in mock_executor.active_workers.task_by_key(airflow_key).task_arn + change_state_mock.assert_called_once_with( + airflow_key, TaskInstanceState.RUNNING, ARN1, remove_running=False + ) @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0)) def test_success_execute_api_exception(self, mock_backoff, mock_executor):