Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: kedro-airflow group in memory nodes #241

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions kedro-airflow/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,19 @@ See ["What if I want to use a different Jinja2 template?"](#what-if-i-want-to-us
The [rich offering](https://airflow.apache.org/docs/apache-airflow-providers/operators-and-hooks-ref/index.html) of operators means that the `kedro-airflow` plugin is providing templates for specific operators.
The default template provided by `kedro-airflow` uses the `BaseOperator`.

### Can I group nodes together?

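When running Kedro nodes using Airflow, MemoryDatasets are often not shared across operators, because they do not persist between them.
A node that needs a MemoryDataset produced in a different operator cannot load it, which causes the DAG run to fail.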

MemoryDatasets may be used to provide logical separation between nodes in Kedro without the overhead of writing to disk (or, when running distributed, of needing multiple executors).

Nodes that are connected through MemoryDatasets are grouped together via the `--group-in-memory` flag.
This preserves the option to have logical separation in Kedro, with little computational overhead.

It is possible to use [task groups](https://docs.astronomer.io/learn/task-groups) by changing the template.
See ["What if I want to use a different Jinja2 template?"](#what-if-i-want-to-use-a-different-jinja2-template) for instructions on using custom templates.

## Can I contribute?

Yes! Want to help build Kedro-Airflow? Check out our guide to [contributing](https://github.com/kedro-org/kedro-plugins/blob/main/kedro-airflow/CONTRIBUTING.md).
Expand Down
1 change: 1 addition & 0 deletions kedro-airflow/RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Upcoming Release
* Added an option to group nodes that are connected through MemoryDatasets into the same Airflow task (breaking change for custom templates passed via `--jinja-file`).

# Release 0.8.0
* Added support for Kedro 0.19.x
Expand Down
19 changes: 10 additions & 9 deletions kedro-airflow/kedro_airflow/airflow_dag_template.j2
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

from datetime import datetime, timedelta
from pathlib import Path

Expand All @@ -16,7 +17,7 @@ class KedroOperator(BaseOperator):
self,
package_name: str,
pipeline_name: str,
node_name: str,
node_name: str | list[str],
project_path: str | Path,
env: str,
*args, **kwargs
Expand All @@ -30,10 +31,10 @@ class KedroOperator(BaseOperator):

def execute(self, context):
configure_project(self.package_name)
with KedroSession.create(project_path=self.project_path,
env=self.env) as session:
session.run(self.pipeline_name, node_names=[self.node_name])

with KedroSession.create(self.project_path, env=self.env) as session:
if isinstance(self.node_name, str):
self.node_name = [self.node_name]
session.run(self.pipeline_name, node_names=self.node_name)

# Kedro settings required to run your pipeline
env = "{{ env }}"
Expand All @@ -60,17 +61,17 @@ with DAG(
)
) as dag:
tasks = {
{% for node in pipeline.nodes %} "{{ node.name | safe | slugify }}": KedroOperator(
task_id="{{ node.name | safe | slugify }}",
{% for node_name, node_list in nodes.items() %} "{{ node_name | safe | slugify }}": KedroOperator(
task_id="{{ node_name | safe | slugify }}",
package_name=package_name,
pipeline_name=pipeline_name,
node_name="{{ node.name | safe }}",
node_name={% if node_list | length > 1 %}[{% endif %}{% for node in node_list %}"{{ node.name | safe | slugify }}"{% if not loop.last %}, {% endif %}{% endfor %}{% if node_list | length > 1 %}]{% endif %},
project_path=project_path,
env=env,
),
{% endfor %} }

{% for parent_node, child_nodes in dependencies.items() -%}
{% for child in child_nodes %} tasks["{{ parent_node.name | safe | slugify }}"] >> tasks["{{ child.name | safe | slugify }}"]
{% for child in child_nodes %} tasks["{{ parent_node | safe | slugify }}"] >> tasks["{{ child | safe | slugify }}"]
{% endfor %}
{%- endfor %}
97 changes: 97 additions & 0 deletions kedro-airflow/kedro_airflow/grouping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from __future__ import annotations

from collections import defaultdict

from kedro.io import DataCatalog, MemoryDataset
from kedro.pipeline.node import Node
from kedro.pipeline.pipeline import Pipeline


def _is_memory_dataset(catalog, dataset_name: str) -> bool:
    """Tell whether ``dataset_name`` is registered in ``catalog`` as a MemoryDataset.

    The ``parameters`` entry and every ``params:``-prefixed entry are never
    reported as memory datasets, and neither is a name that is absent from
    the catalog's ``_datasets`` mapping.
    """
    is_parameter = dataset_name == "parameters" or dataset_name.startswith("params:")
    if is_parameter:
        return False

    # Looks the entry up in the catalog's private ``_datasets`` mapping.
    registered = catalog._datasets.get(dataset_name)
    if registered is None:
        return False
    return isinstance(registered, MemoryDataset)


def get_memory_datasets(catalog: DataCatalog, pipeline: Pipeline) -> set[str]:
    """Collect the names of the pipeline's datasets that are MemoryDatasets.

    Every name returned by ``pipeline.datasets()`` is checked with
    ``_is_memory_dataset``, so parameter entries are never included.
    """
    memory_names: set[str] = set()
    for candidate in pipeline.datasets():
        if _is_memory_dataset(catalog, candidate):
            memory_names.add(candidate)
    return memory_names


def node_sequence_name(node_sequence: list[Node]) -> str:
    """Build a group name by joining the member node names with underscores."""
    return "_".join(member.name for member in node_sequence)


def group_memory_nodes(
    catalog: DataCatalog, pipeline: Pipeline
) -> tuple[dict[str, list[Node]], dict[str, list[str]]]:
    """Group nodes that are connected through MemoryDatasets.

    Nodes that are connected through MemoryDatasets cannot be distributed across
    multiple machines, e.g. be in different Kubernetes pods. This function
    computes the connected components of the graph in which two nodes are
    linked whenever they read or write the same MemoryDataset.

    Unlike a purely sequential assignment, groups are merged when a node
    consumes MemoryDatasets that originate from different groups, and a
    MemoryDataset that no node in the pipeline produces does not raise.

    Args:
        catalog: Catalog used to decide which datasets are MemoryDatasets.
        pipeline: Pipeline whose nodes are grouped.

    Returns:
        A tuple ``(nodes, dependencies)``:

        * ``nodes`` maps each group name (member node names joined by ``_``)
          to the list of its nodes, in ``pipeline.nodes`` order; groups are
          ordered by their first node.
        * ``dependencies`` maps a parent group name to the list of distinct
          child group names that depend on it. Dependencies inside a single
          group are dropped.

    NOTE(review): grouping can make the group-level graph cyclic when two
    nodes of one group are also linked through an outside node via
    non-memory datasets; this is not detected here.
    """
    memory_datasets = get_memory_datasets(catalog, pipeline)
    pipeline_nodes = list(pipeline.nodes)

    # Disjoint-set forest over node positions; every node starts on its own.
    representative = list(range(len(pipeline_nodes)))

    def find(position: int) -> int:
        while representative[position] != position:
            # Path halving keeps the trees shallow.
            representative[position] = representative[representative[position]]
            position = representative[position]
        return position

    # Position of the first node seen reading or writing each memory dataset.
    first_user: dict[str, int] = {}
    for position, node in enumerate(pipeline_nodes):
        for dataset_name in node.inputs + node.outputs:
            if dataset_name not in memory_datasets:
                continue
            anchor = first_user.setdefault(dataset_name, position)
            root_a, root_b = find(anchor), find(position)
            if root_a != root_b:
                # The earliest node stays the representative of the group.
                representative[max(root_a, root_b)] = min(root_a, root_b)

    # Collect groups in pipeline order so that both the groups and their
    # members keep the ordering of ``pipeline.nodes``.
    grouped: dict[int, list[Node]] = {}
    for position, node in enumerate(pipeline_nodes):
        grouped.setdefault(find(position), []).append(node)

    # Named node sequences
    nodes = {node_sequence_name(sequence): sequence for sequence in grouped.values()}

    # Inverted mapping: node name -> name of the group containing it
    node_mapping = {
        member.name: sequence_name
        for sequence_name, sequence in nodes.items()
        for member in sequence
    }

    # Grouped dependencies, de-duplicated and without self-references
    dependencies = defaultdict(list)
    for node, parent_nodes in pipeline.node_dependencies.items():
        node_name = node_mapping[node.name]
        for parent_node in parent_nodes:
            parent_name = node_mapping[parent_node.name]
            if parent_name != node_name and node_name not in dependencies[parent_name]:
                dependencies[parent_name].append(node_name)

    return nodes, dependencies
26 changes: 22 additions & 4 deletions kedro-airflow/kedro_airflow/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from kedro.framework.startup import ProjectMetadata, bootstrap_project
from slugify import slugify

from kedro_airflow.grouping import group_memory_nodes

PIPELINE_ARG_HELP = """Name of the registered pipeline to convert.
If not set, the '__default__' pipeline is used. This argument supports
passing multiple values using `--pipeline [p1] --pipeline [p2]`.
Expand Down Expand Up @@ -100,6 +102,14 @@ def _get_pipeline_config(config_airflow: dict, params: dict, pipeline_name: str)
default=Path(__file__).parent / "airflow_dag_template.j2",
help="The template file for the generated Airflow dags",
)
@click.option(
"-g",
"--group-in-memory",
is_flag=True,
default=False,
help="Group nodes with at least one MemoryDataset as input/output together, "
"as they do not persist between Airflow operators.",
)
@click.option(
"--params",
type=click.UNPROCESSED,
Expand All @@ -114,6 +124,7 @@ def create( # noqa: PLR0913
env,
target_path,
jinja_file,
group_in_memory,
params,
convert_all: bool,
):
Expand Down Expand Up @@ -165,13 +176,20 @@ def create( # noqa: PLR0913
else f"{package_name}_{name}_dag.py"
)

dependencies = defaultdict(list)
for node, parent_nodes in pipeline.node_dependencies.items():
for parent in parent_nodes:
dependencies[parent].append(node)
# group memory nodes
if group_in_memory:
nodes, dependencies = group_memory_nodes(context.catalog, pipeline)
else:
nodes = {node.name: [node] for node in pipeline.nodes}

dependencies = defaultdict(list)
for node, parent_nodes in pipeline.node_dependencies.items():
for parent in parent_nodes:
dependencies[parent.name].append(node.name)

template.stream(
dag_name=package_name,
nodes=nodes,
dependencies=dependencies,
env=env,
pipeline_name=name,
Expand Down
2 changes: 1 addition & 1 deletion kedro-airflow/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ description = "Kedro-Airflow makes it easy to deploy Kedro projects to Airflow"
requires-python = ">=3.8"
license = {text = "Apache Software License (Apache 2.0)"}
dependencies = [
"kedro>=0.17.5",
"kedro>=0.19.0",
"python-slugify>=4.0",
"semver>=2.10", # Needs to be at least 2.10.0 to make use of `VersionInfo.match`.
]
Expand Down
128 changes: 128 additions & 0 deletions kedro-airflow/tests/test_node_grouping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from __future__ import annotations

from typing import Any

import pytest
from kedro.io import AbstractDataset, DataCatalog, MemoryDataset
from kedro.pipeline import node
from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline

from kedro_airflow.grouping import _is_memory_dataset, group_memory_nodes


class TestDataset(AbstractDataset):
    """Minimal dataset stub that is not a MemoryDataset, for catalog tests."""

    def _describe(self) -> dict[str, Any]:
        """Describe the stub with an empty mapping."""
        return {}

    def _load(self):
        """Load an empty list."""
        return []

    def _save(self, data) -> None:
        """Discard the data."""
        return None


@pytest.mark.parametrize(
    "memory_nodes,expected_nodes,expected_dependencies",
    [
        (
            ["ds3", "ds6"],
            [["f1"], ["f2", "f3", "f4", "f6", "f7"], ["f5"]],
            {"f1": ["f2_f3_f4_f6_f7"], "f2_f3_f4_f6_f7": ["f5"]},
        ),
        (
            ["ds3"],
            [["f1"], ["f2", "f3", "f4", "f7"], ["f5"], ["f6"]],
            {"f1": ["f2_f3_f4_f7"], "f2_f3_f4_f7": ["f5", "f6"]},
        ),
        (
            [],
            [["f1"], ["f2"], ["f3"], ["f4"], ["f5"], ["f6"], ["f7"]],
            {"f1": ["f2"], "f2": ["f3", "f4", "f5", "f7"], "f4": ["f6", "f7"]},
        ),
    ],
)
def test_group_memory_nodes(
    memory_nodes: list[str],
    expected_nodes: list[list[str]],
    expected_dependencies: dict[str, list[str]],
):
    """Group a seven-node pipeline for several choices of MemoryDatasets."""
    dataset_names = [f"ds{index}" for index in range(1, 10)]
    assert set(memory_nodes) <= set(dataset_names)

    # Datasets listed in ``memory_nodes`` are MemoryDatasets; the rest are stubs.
    mock_catalog = DataCatalog()
    for dataset_name in dataset_names:
        dataset = MemoryDataset() if dataset_name in memory_nodes else TestDataset()
        mock_catalog.add(dataset_name, dataset)

    def identity_one_to_one(x):
        return x

    # (func, inputs, outputs, name) for every node of the pipeline under test.
    node_specs = [
        (identity_one_to_one, "ds1", "ds2", "f1"),
        (lambda x: (x, x), "ds2", ["ds3", "ds4"], "f2"),
        (identity_one_to_one, "ds3", "ds5", "f3"),
        (identity_one_to_one, "ds3", "ds6", "f4"),
        (identity_one_to_one, "ds4", "ds8", "f5"),
        (identity_one_to_one, "ds6", "ds7", "f6"),
        (lambda x, y: x, ["ds3", "ds6"], "ds9", "f7"),
    ]
    mock_pipeline = modular_pipeline(
        [
            node(func=func, inputs=inputs, outputs=outputs, name=name)
            for func, inputs, outputs, name in node_specs
        ],
    )

    grouped_nodes, dependencies = group_memory_nodes(mock_catalog, mock_pipeline)
    grouped_names = [
        [member.name for member in group] for group in grouped_nodes.values()
    ]
    assert grouped_names == expected_nodes
    assert dict(dependencies) == expected_dependencies


def test_is_memory_dataset():
    """Only an explicitly registered MemoryDataset counts as a memory dataset."""
    catalog = DataCatalog()
    entries = {
        "parameters": {"hello": "world"},
        "params:hello": "world",
        "my_dataset": MemoryDataset(True),
        "test_dataset": TestDataset(),
    }
    for entry_name, entry in entries.items():
        catalog.add(entry_name, entry)

    expectations = {
        "parameters": False,
        "params:hello": False,
        "my_dataset": True,
        "test_dataset": False,
    }
    for entry_name, is_memory in expectations.items():
        assert bool(_is_memory_dataset(catalog, entry_name)) is is_memory
2 changes: 2 additions & 0 deletions kedro-airflow/tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
("hello_world", "__default__", ["airflow", "create"]),
# Test execution with alternate pipeline name
("hello_world", "ds", ["airflow", "create", "--pipeline", "ds"]),
# Test with grouping
("hello_world", "__default__", ["airflow", "create", "--group-in-memory"]),
],
)
def test_create_airflow_dag(dag_name, pipeline_name, command, cli_runner, metadata):
Expand Down
Loading