You are viewing an outdated version of the documentation.

This documentation is for an older version (1.4.7) of Dagster. You can view the version of this page from our latest release below.

Source code for dagster._core.storage.upath_io_manager

import asyncio
import inspect
from abc import abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union

from fsspec import AbstractFileSystem
from fsspec.implementations.local import LocalFileSystem

from dagster import (
    InputContext,
    MetadataValue,
    MultiPartitionKey,
    OutputContext,
    _check as check,
)
from dagster._core.storage.memoizable_io_manager import MemoizableIOManager

if TYPE_CHECKING:
    from upath import UPath


[docs]class UPathIOManager(MemoizableIOManager): """Abstract IOManager base class compatible with local and cloud storage via `universal-pathlib` and `fsspec`. Features: - handles partitioned assets - handles loading a single upstream partition - handles loading multiple upstream partitions (with respect to :py:class:`PartitionMapping`) - supports loading multiple partitions concurrently with async `load_from_path` method - the `get_metadata` method can be customized to add additional metadata to the output - the `allow_missing_partitions` metadata value can be set to `True` to skip missing partitions (the default behavior is to raise an error) """ extension: Optional[str] = None # override in child class def __init__( self, base_path: Optional["UPath"] = None, ): from upath import UPath assert not self.extension or "." in self.extension self._base_path = base_path or UPath(".") @abstractmethod def dump_to_path(self, context: OutputContext, obj: Any, path: "UPath"): """Child classes should override this method to write the object to the filesystem.""" @abstractmethod def load_from_path(self, context: InputContext, path: "UPath") -> Any: """Child classes should override this method to load the object from the filesystem.""" @property def fs(self) -> AbstractFileSystem: """Utility function to get the IOManager filesystem. Returns: AbstractFileSystem: fsspec filesystem. """ from upath import UPath if isinstance(self._base_path, UPath): return self._base_path.fs elif isinstance(self._base_path, Path): return LocalFileSystem() else: raise ValueError(f"Unsupported base_path type: {type(self._base_path)}") @property def storage_options(self) -> Dict[str, Any]: """Utility function to get the fsspec storage_options which are often consumed by various I/O functions. Returns: Dict[str, Any]: fsspec storage_options. """ from upath import UPath if isinstance(self._base_path, UPath): return self._base_path._kwargs.copy() # noqa elif isinstance(self._base_path, Path): return {} else: raise ValueError(f"Unsupported base_path type: {type(self._base_path)}") def get_metadata( self, context: OutputContext, obj: Any, ) -> Dict[str, MetadataValue]: """Child classes should override this method to add custom metadata to the outputs.""" return {} # Read/write operations on paths can generally be handled by methods on the # UPath class, but when the backend requires credentials, this isn't # always possible. Override these path_* methods to provide custom # implementations for targeting backends that require authentication. def unlink(self, path: "UPath") -> None: """Remove the file or object at the provided path.""" path.unlink() def path_exists(self, path: "UPath") -> bool: """Check if a file or object exists at the provided path.""" return path.exists() def make_directory(self, path: "UPath"): """Create a directory at the provided path. Override as a no-op if the target backend doesn't use directories. """ path.mkdir(parents=True, exist_ok=True) def has_output(self, context: OutputContext) -> bool: return self.path_exists(self._get_path(context)) def _with_extension(self, path: "UPath") -> "UPath": return path.with_suffix(path.suffix + self.extension) if self.extension else path def _get_path_without_extension(self, context: Union[InputContext, OutputContext]) -> "UPath": if context.has_asset_key: context_path = self.get_asset_relative_path(context) else: # we are dealing with an op output context_path = self.get_op_output_relative_path(context) return self._base_path.joinpath(context_path) def get_asset_relative_path(self, context: Union[InputContext, OutputContext]) -> "UPath": from upath import UPath # we are not using context.get_asset_identifier() because it already includes the partition_key return UPath(*context.asset_key.path) def get_op_output_relative_path(self, context: Union[InputContext, OutputContext]) -> "UPath": from upath import UPath return UPath(*context.get_identifier()) def get_loading_input_log_message(self, path: "UPath") -> str: return f"Loading file from: {path} using {self.__class__.__name__}..." def get_writing_output_log_message(self, path: "UPath") -> str: return f"Writing file at: {path} using {self.__class__.__name__}..." def get_loading_input_partition_log_message(self, path: "UPath", partition_key: str) -> str: return f"Loading partition {partition_key} from {path} using {self.__class__.__name__}..." def get_missing_partition_log_message(self, partition_key: str) -> str: return ( f"Couldn't load partition {partition_key} and skipped it " "because the input metadata includes allow_missing_partitions=True" ) def _get_path(self, context: Union[InputContext, OutputContext]) -> "UPath": """Returns the I/O path for a given context. Should not be used with partitions (use `_get_paths_for_partitions` instead). """ path = self._get_path_without_extension(context) return self._with_extension(path) def get_path_for_partition( self, context: Union[InputContext, OutputContext], path: "UPath", partition: str ) -> "UPath": """Override this method if you want to use a different partitioning scheme (for example, if the saving function handles partitioning instead). The extension will be added later. Args: context (Union[InputContext, OutputContext]): The context for the I/O operation. path (UPath): The path to the file or object. partition (str): Formatted partition/multipartition key Returns: UPath: The path to the file with the partition key appended. """ return path / partition def _get_paths_for_partitions( self, context: Union[InputContext, OutputContext] ) -> Dict[str, "UPath"]: """Returns a dict of partition_keys into I/O paths for a given context.""" if not context.has_asset_partitions: raise TypeError( f"Detected {context.dagster_type.typing_type} input type " "but the asset is not partitioned" ) def _formatted_multipartitioned_path(partition_key: MultiPartitionKey) -> str: ordered_dimension_keys = [ key[1] for key in sorted(partition_key.keys_by_dimension.items(), key=lambda x: x[0]) ] return "/".join(ordered_dimension_keys) formatted_partition_keys = { partition_key: ( _formatted_multipartitioned_path(partition_key) if isinstance(partition_key, MultiPartitionKey) else partition_key ) for partition_key in context.asset_partition_keys } asset_path = self._get_path_without_extension(context) return { partition_key: self._with_extension( self.get_path_for_partition(context, asset_path, partition) ) for partition_key, partition in formatted_partition_keys.items() } def _get_multipartition_backcompat_paths( self, context: Union[InputContext, OutputContext] ) -> Mapping[str, "UPath"]: if not context.has_asset_partitions: raise TypeError( f"Detected {context.dagster_type.typing_type} input type " "but the asset is not partitioned" ) partition_keys = context.asset_partition_keys asset_path = self._get_path_without_extension(context) return { partition_key: self._with_extension(asset_path / partition_key) for partition_key in partition_keys if isinstance(partition_key, MultiPartitionKey) } def _load_single_input( self, path: "UPath", context: InputContext, backcompat_path: Optional["UPath"] = None ) -> Any: context.log.debug(self.get_loading_input_log_message(path)) try: obj = self.load_from_path(context=context, path=path) if asyncio.iscoroutine(obj): obj = asyncio.run(obj) except FileNotFoundError as e: if backcompat_path is not None: try: obj = self.load_from_path(context=context, path=backcompat_path) if asyncio.iscoroutine(obj): obj = asyncio.run(obj) context.log.debug( f"File not found at {path}. Loaded instead from backcompat path:" f" {backcompat_path}" ) except FileNotFoundError: raise e else: raise e context.add_input_metadata({"path": MetadataValue.path(str(path))}) return obj def _load_partition_from_path( self, context: InputContext, partition_key: str, path: "UPath", backcompat_path: Optional["UPath"] = None, ) -> Any: """1. Try to load the partition from the normal path. 2. If it was not found, try to load it from the backcompat path. 3. If allow_missing_partitions metadata is True, skip the partition if it was not found in any of the paths. Otherwise, raise an error. Args: context (InputContext): IOManager Input context partition_key (str): the partition key corresponding to the partition being loaded path (UPath): The path to the partition. backcompat_path (Optional[UPath]): The path to the partition in the backcompat location. Returns: Any: The object loaded from the partition. """ allow_missing_partitions = ( context.metadata.get("allow_missing_partitions", False) if context.metadata is not None else False ) try: context.log.debug(self.get_loading_input_partition_log_message(path, partition_key)) obj = self.load_from_path(context=context, path=path) return obj except FileNotFoundError as e: if backcompat_path is not None: try: obj = self.load_from_path(context=context, path=path) context.log.debug( f"File not found at {path}. Loaded instead from backcompat path:" f" {backcompat_path}" ) return obj except FileNotFoundError as e: if allow_missing_partitions: context.log.warning(self.get_missing_partition_log_message(partition_key)) return None else: raise e if allow_missing_partitions: context.log.warning(self.get_missing_partition_log_message(partition_key)) return None else: raise e def _load_multiple_inputs(self, context: InputContext) -> Dict[str, Any]: # load multiple partitions paths = self._get_paths_for_partitions(context) # paths for normal partitions backcompat_paths = self._get_multipartition_backcompat_paths( context ) # paths for multipartitions context.log.debug(f"Loading {len(paths)} partitions...") objs = {} if not inspect.iscoroutinefunction(self.load_from_path): for partition_key in context.asset_partition_keys: obj = self._load_partition_from_path( context, partition_key, paths[partition_key], backcompat_paths.get(partition_key), ) if obj is not None: # in case some partitions were skipped objs[partition_key] = obj return objs else: # load_from_path returns a coroutine, so we need to await the results async def collect(): loop = asyncio.get_running_loop() tasks = [] for partition_key in context.asset_partition_keys: tasks.append( loop.create_task( self._load_partition_from_path( context, partition_key, paths[partition_key], backcompat_paths.get(partition_key), ) ) ) results = await asyncio.gather(*tasks, return_exceptions=True) # need to handle missing partitions here because exceptions don't get propagated from async calls allow_missing_partitions = ( context.metadata.get("allow_missing_partitions", False) if context.metadata is not None else False ) results_without_errors = [] found_errors = False for partition_key, result in zip(context.asset_partition_keys, results): if isinstance(result, FileNotFoundError): if allow_missing_partitions: context.log.warning( self.get_missing_partition_log_message(partition_key) ) else: context.log.error(str(result)) found_errors = True elif isinstance(result, Exception): context.log.error(str(result)) found_errors = True else: results_without_errors.append(result) if found_errors: raise RuntimeError( f"{len(paths) - len(results_without_errors)} partitions could not be loaded" ) return results_without_errors awaited_objects = asyncio.get_event_loop().run_until_complete(collect()) return { partition_key: awaited_object for partition_key, awaited_object in zip( context.asset_partition_keys, awaited_objects ) if awaited_object is not None } def load_input(self, context: InputContext) -> Union[Any, Dict[str, Any]]: # If no asset key, we are dealing with an op output which is always non-partitioned if not context.has_asset_key or not context.has_asset_partitions: path = self._get_path(context) return self._load_single_input(path, context) else: asset_partition_keys = context.asset_partition_keys if len(asset_partition_keys) == 0: return None elif len(asset_partition_keys) == 1: paths = self._get_paths_for_partitions(context) check.invariant(len(paths) == 1, f"Expected 1 path, but got {len(paths)}") path = list(paths.values())[0] backcompat_paths = self._get_multipartition_backcompat_paths(context) backcompat_path = ( None if not backcompat_paths else list(backcompat_paths.values())[0] ) return self._load_single_input(path, context, backcompat_path) else: # we are dealing with multiple partitions of an asset type_annotation = context.dagster_type.typing_type if type_annotation != Any and not is_dict_type(type_annotation): check.failed( "Loading an input that corresponds to multiple partitions, but the" " type annotation on the op input is not a dict, Dict, Mapping, or" f" Any: is '{type_annotation}'." ) return self._load_multiple_inputs(context) def handle_output(self, context: OutputContext, obj: Any): if context.dagster_type.typing_type == type(None): check.invariant( obj is None, "Output had Nothing type or 'None' annotation, but handle_output received" f" value that was not None and was of type {type(obj)}.", ) return None if context.has_asset_partitions: paths = self._get_paths_for_partitions(context) assert len(paths) == 1 path = list(paths.values())[0] else: path = self._get_path(context) self.make_directory(path.parent) context.log.debug(self.get_writing_output_log_message(path)) self.dump_to_path(context=context, obj=obj, path=path) metadata = {"path": MetadataValue.path(str(path))} custom_metadata = self.get_metadata(context=context, obj=obj) metadata.update(custom_metadata) # type: ignore context.add_output_metadata(metadata)
def is_dict_type(type_obj) -> bool: if type_obj == dict: return True if hasattr(type_obj, "__origin__") and type_obj.__origin__ in (dict, Dict, Mapping): return True return False