Source code for dagster_airflow.dagster_asset_factory

from typing import AbstractSet, List, Mapping, Optional, Set, Tuple

from airflow.models.connection import Connection
from airflow.models.dag import DAG
from dagster import (
    AssetKey,
    AssetsDefinition,
    GraphDefinition,
    OutputMapping,
    TimeWindowPartitionsDefinition,
)
from dagster._core.definitions.graph_definition import create_adjacency_lists
from dagster._utils.schedules import is_valid_cron_schedule

from dagster_airflow.dagster_job_factory import make_dagster_job_from_airflow_dag
from dagster_airflow.utils import (
    DagsterAirflowError,
    normalized_name,
)


def _build_asset_dependencies(
    dag: DAG,
    graph: GraphDefinition,
    task_ids_by_asset_key: Mapping[AssetKey, AbstractSet[str]],
    upstream_dependencies_by_asset_key: Mapping[AssetKey, AbstractSet[AssetKey]],
) -> Tuple[AbstractSet[OutputMapping], Mapping[str, AssetKey], Mapping[str, Set[AssetKey]]]:
    """Builds the asset dependency graph for a given set of airflow task mappings and a dagster graph."""
    output_mappings = set()
    keys_by_output_name = {}
    internal_asset_deps: dict[str, Set[AssetKey]] = {}

    visited_nodes: dict[str, bool] = {}
    upstream_deps = set()

    def find_upstream_dependency(node_name: str) -> None:
        """Uses depth-first search to find all upstream asset dependencies,
        as described in task_ids_by_asset_key.
        """
        # skip nodes that have already been visited
        if visited_nodes[node_name]:
            return
        # mark node as visited
        visited_nodes[node_name] = True
        # traverse upstream nodes
        for output_handle in graph.dependency_structure.all_upstream_outputs_from_node(node_name):
            forward_node = output_handle.node_name
            match = False
            # find any assets produced by upstream nodes and add them to the internal asset deps
            for asset_key in task_ids_by_asset_key:
                if (
                    forward_node.replace(f"{normalized_name(dag.dag_id)}__", "")
                    in task_ids_by_asset_key[asset_key]
                ):
                    upstream_deps.add(asset_key)
                    match = True
            # don't traverse past nodes that have assets
            if not match:
                find_upstream_dependency(forward_node)

    # iterate through each asset to find all upstream asset dependencies
    for asset_key in task_ids_by_asset_key:
        asset_upstream_deps = set()
        for task_id in task_ids_by_asset_key[asset_key]:
            visited_nodes = {s.name: False for s in graph.nodes}
            upstream_deps = set()
            find_upstream_dependency(normalized_name(dag.dag_id, task_id))
            for dep in upstream_deps:
                asset_upstream_deps.add(dep)
            keys_by_output_name[f"result_{normalized_name(dag.dag_id, task_id)}"] = asset_key
            output_mappings.add(
                OutputMapping(
                    graph_output_name=f"result_{normalized_name(dag.dag_id, task_id)}",
                    mapped_node_name=normalized_name(dag.dag_id, task_id),
                    mapped_node_output_name="airflow_task_complete",  # Default output name
                )
            )

        # the tasks for a given asset should have the same internal deps
        for task_id in task_ids_by_asset_key[asset_key]:
            if f"result_{normalized_name(dag.dag_id, task_id)}" in internal_asset_deps:
                internal_asset_deps[f"result_{normalized_name(dag.dag_id, task_id)}"].update(
                    asset_upstream_deps
                )
            else:
                internal_asset_deps[f"result_{normalized_name(dag.dag_id, task_id)}"] = (
                    asset_upstream_deps
                )

    # add new upstream asset dependencies to the internal deps
    for asset_key in upstream_dependencies_by_asset_key:
        for key in keys_by_output_name:
            if keys_by_output_name[key] == asset_key:
                internal_asset_deps[key].update(upstream_dependencies_by_asset_key[asset_key])

    return (output_mappings, keys_by_output_name, internal_asset_deps)
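
# A worked sketch of the DFS in _build_asset_dependencies (hypothetical DAG, for
# illustration; assumes normalized_name("etl", "load") == "etl__load"): given a
# DAG "etl" with tasks extract >> transform >> load and the mapping
#   {AssetKey("staging"): {"transform"}, AssetKey("warehouse"): {"load"}},
# the search from "etl__load" stops at "etl__transform" because that task
# produces an asset, so internal_asset_deps maps "result_etl__load" to
# {AssetKey("staging")} and "result_etl__transform" to set().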


def load_assets_from_airflow_dag(
    dag: DAG,
    task_ids_by_asset_key: Mapping[AssetKey, AbstractSet[str]] = {},
    upstream_dependencies_by_asset_key: Mapping[AssetKey, AbstractSet[AssetKey]] = {},
    connections: Optional[List[Connection]] = None,
) -> List[AssetsDefinition]:
    """[Experimental] Construct Dagster Assets for a given Airflow DAG.

    Args:
        dag (DAG): The Airflow DAG to compile into a Dagster job.
        task_ids_by_asset_key (Optional[Mapping[AssetKey, AbstractSet[str]]]): A mapping from asset
            keys to task ids. Used to break up the Airflow DAG into multiple SDAs.
        upstream_dependencies_by_asset_key (Optional[Mapping[AssetKey, AbstractSet[AssetKey]]]): A
            mapping from upstream asset keys to assets provided in task_ids_by_asset_key. Used to
            declare new upstream SDA dependencies.
        connections (List[Connection]): List of Airflow Connections to be created in the Airflow DB.

    Returns:
        List[AssetsDefinition]
    """
    cron_schedule = dag.normalized_schedule_interval
    if cron_schedule is not None and not is_valid_cron_schedule(str(cron_schedule)):
        raise DagsterAirflowError(f"Invalid cron schedule: {cron_schedule} in DAG {dag.dag_id}")

    job = make_dagster_job_from_airflow_dag(dag, connections=connections)
    graph = job._graph_def  # noqa: SLF001

    start_date = dag.start_date if dag.start_date else dag.default_args.get("start_date")
    if start_date is None:
        raise DagsterAirflowError(f"Invalid start_date: {start_date} in DAG {dag.dag_id}")

    # leaf nodes have no downstream nodes
    forward_edges, _ = create_adjacency_lists(graph.nodes, graph.dependency_structure)
    leaf_nodes = {
        node_name.replace(f"{normalized_name(dag.dag_id)}__", "")
        for node_name, downstream_nodes in forward_edges.items()
        if not downstream_nodes
    }

    mutated_task_ids_by_asset_key: dict[AssetKey, set[str]] = {}

    if task_ids_by_asset_key is None or task_ids_by_asset_key == {}:
        # if no mappings are provided the dag becomes a single SDA
        task_ids_by_asset_key = {AssetKey(dag.dag_id): leaf_nodes}
    else:
        # if mappings were provided, any unmapped leaf nodes are added to a default asset
        used_nodes: set[str] = set()
        for key in task_ids_by_asset_key:
            used_nodes.update(task_ids_by_asset_key[key])

        mutated_task_ids_by_asset_key[AssetKey(dag.dag_id)] = leaf_nodes - used_nodes

    for key in task_ids_by_asset_key:
        if key not in mutated_task_ids_by_asset_key:
            mutated_task_ids_by_asset_key[key] = set(task_ids_by_asset_key[key])
        else:
            mutated_task_ids_by_asset_key[key].update(task_ids_by_asset_key[key])

    output_mappings, keys_by_output_name, internal_asset_deps = _build_asset_dependencies(
        dag, graph, mutated_task_ids_by_asset_key, upstream_dependencies_by_asset_key
    )

    new_graph = graph.copy(
        output_mappings=list(output_mappings),
    )

    asset_def = AssetsDefinition.from_graph(
        graph_def=new_graph,
        partitions_def=(
            TimeWindowPartitionsDefinition(
                cron_schedule=str(cron_schedule),
                timezone=dag.timezone.name,
                start=start_date.strftime("%Y-%m-%dT%H:%M:%S"),
                fmt="%Y-%m-%dT%H:%M:%S",
            )
            if cron_schedule is not None
            else None
        ),
        group_name=dag.dag_id,
        keys_by_output_name=keys_by_output_name,
        internal_asset_deps=internal_asset_deps,
        can_subset=True,
    )
    return [asset_def]
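
# Usage sketch (not part of this module; the DAG id, task ids, and asset keys
# below are hypothetical, chosen for illustration):
#
#     from datetime import datetime
#
#     from airflow.models.dag import DAG
#     from airflow.operators.bash import BashOperator
#     from dagster import AssetKey, Definitions
#     from dagster_airflow import load_assets_from_airflow_dag
#
#     with DAG(
#         dag_id="etl",
#         schedule_interval="@daily",
#         start_date=datetime(2023, 1, 1),
#     ) as dag:
#         extract = BashOperator(task_id="extract", bash_command="echo extract")
#         load = BashOperator(task_id="load", bash_command="echo load")
#         extract >> load
#
#     # with no task mapping the whole DAG becomes a single SDA; the mapping
#     # below instead ties the "load" task to its own asset key
#     assets = load_assets_from_airflow_dag(
#         dag,
#         task_ids_by_asset_key={AssetKey("warehouse_table"): {"load"}},
#     )
#     defs = Definitions(assets=assets)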