Source code for dagster_databricks.databricks

import base64
import logging
import time
from typing import Any, Mapping, Optional

import dagster
import dagster._check as check
import dagster_pyspark
import requests.exceptions
from dagster._annotations import public
from databricks_api import DatabricksAPI
from databricks_cli.sdk import ApiClient, ClusterService, DbfsService, JobsService

import dagster_databricks

from .types import (
    DatabricksRunLifeCycleState,
    DatabricksRunResultState,
    DatabricksRunState,
)
from .version import __version__

# wait at most 24 hours by default for run execution
DEFAULT_RUN_MAX_WAIT_TIME_SEC = 24 * 60 * 60


class DatabricksError(Exception):
    pass
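
`DatabricksError` is raised by the polling helpers below when a run fails or exceeds its time limit. A minimal, hypothetical error-handling sketch (the `client` object is a `DatabricksClient` as defined below; the run ID and timeouts are placeholders):

try:
    client.wait_for_run_to_complete(
        logger=logging.getLogger(__name__),
        databricks_run_id=42,
        poll_interval_sec=5,
        max_wait_time_sec=10 * 60,
    )
except DatabricksError as exc:
    # Raised when the run fails or takes longer than max_wait_time_sec.
    print(f"Run did not succeed: {exc}")
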
class DatabricksClient:
    """A thin wrapper over the Databricks REST API."""

    def __init__(self, host: str, token: str, workspace_id: Optional[str] = None):
        self.host = host
        self.workspace_id = workspace_id

        # TODO: This is the old shim client that we were previously using. Arguably this is
        # confusing for users to use since this is an unofficial wrapper around the documented
        # Databricks REST API. We should consider removing this in the future.
        self.client = DatabricksAPI(host=host, token=token)
        self.__setup_user_agent(self.client.client)

        # Expose an interface directly to the official Databricks API client.
        self._api_client = ApiClient(host=host, token=token)
        self.__setup_user_agent(self._api_client)

    def __setup_user_agent(self, client: ApiClient) -> None:
        """Overrides the user agent for the Databricks API client."""
        client.default_headers["user-agent"] = f"dagster-databricks/{__version__}"

    @public
    @property
    def api_client(self) -> ApiClient:
        """Retrieve a reference to the underlying Databricks API client. For more information,
        see the `Databricks Python API <https://docs.databricks.com/dev-tools/python-api.html>`_.

        **Examples:**

        .. code-block:: python

            from dagster import op
            from databricks_cli.jobs.api import JobsApi
            from databricks_cli.runs.api import RunsApi

            @op(required_resource_keys={"databricks_client"})
            def op1(context):
                # Initialize the Databricks Jobs API
                jobs_client = JobsApi(context.resources.databricks_client.api_client)
                runs_client = RunsApi(context.resources.databricks_client.api_client)

                # Example 1: Run a Databricks job with some parameters.
                jobs_client.run_now(...)

                # Example 2: Trigger a one-time run of a Databricks workload.
                runs_client.submit_run(...)

                # Example 3: Get an existing run.
                runs_client.get_run(...)

                # Example 4: Cancel a run.
                runs_client.cancel_run(...)

        Returns:
            ApiClient: The authenticated Databricks API client.
        """
        return self._api_client

    def read_file(self, dbfs_path: str, block_size: int = 1024**2) -> bytes:
        """Read a file from DBFS to a **byte string**."""
        if dbfs_path.startswith("dbfs://"):
            dbfs_path = dbfs_path[7:]

        data = b""
        bytes_read = 0
        dbfs_service = DbfsService(self.api_client)

        jdoc = dbfs_service.read(path=dbfs_path, length=block_size)
        data += base64.b64decode(jdoc["data"])
        while jdoc["bytes_read"] == block_size:
            bytes_read += jdoc["bytes_read"]
            jdoc = dbfs_service.read(path=dbfs_path, offset=bytes_read, length=block_size)
            data += base64.b64decode(jdoc["data"])

        return data

    def put_file(
        self, file_obj, dbfs_path: str, overwrite: bool = False, block_size: int = 1024**2
    ) -> None:
        """Upload an arbitrarily large file to DBFS.

        This doesn't use the DBFS `Put` API because that endpoint is limited to 1MB.
        """
        if dbfs_path.startswith("dbfs://"):
            dbfs_path = dbfs_path[7:]

        dbfs_service = DbfsService(self.api_client)

        create_response = dbfs_service.create(path=dbfs_path, overwrite=overwrite)
        handle = create_response["handle"]

        block = file_obj.read(block_size)
        while block:
            data = base64.b64encode(block).decode("utf-8")
            dbfs_service.add_block(data=data, handle=handle)
            block = file_obj.read(block_size)

        dbfs_service.close(handle=handle)

    def get_run_state(self, databricks_run_id: int) -> "DatabricksRunState":
        """Get the state of a run by Databricks run ID.

        Return a `DatabricksRunState` object. Note that the `result_state` attribute may be
        `None` if the run hasn't yet terminated.
""" run = JobsService(self.api_client).get_run(databricks_run_id) state = run["state"] result_state = ( DatabricksRunResultState(state.get("result_state")) if state.get("result_state") else None ) return DatabricksRunState( life_cycle_state=DatabricksRunLifeCycleState(state["life_cycle_state"]), result_state=result_state, state_message=state["state_message"], ) def poll_run_state( self, logger: logging.Logger, start_poll_time: float, databricks_run_id: int, max_wait_time_sec: float, verbose_logs: bool = True, ) -> bool: run_state = self.get_run_state(databricks_run_id) if run_state.has_terminated(): if run_state.is_successful(): logger.info(f"Run `{databricks_run_id}` completed successfully.") return True if run_state.is_skipped(): logger.info(f"Run `{databricks_run_id}` was skipped.") return True else: error_message = ( f"Run `{databricks_run_id}` failed with result state:" f" `{run_state.result_state}`. Message: {run_state.state_message}." ) logger.error(error_message) raise DatabricksError(error_message) else: if verbose_logs: logger.debug(f"Run `{databricks_run_id}` in state {run_state}.") if time.time() - start_poll_time > max_wait_time_sec: raise DatabricksError( f"Run `{databricks_run_id}` took more than {max_wait_time_sec}s to complete." " Failing the run." ) return False def wait_for_run_to_complete( self, logger: logging.Logger, databricks_run_id: int, poll_interval_sec: float, max_wait_time_sec: int, verbose_logs: bool = True, ) -> None: logger.info(f"Waiting for Databricks run `{databricks_run_id}` to complete...") start_poll_time = time.time() while True: if self.poll_run_state( logger=logger, start_poll_time=start_poll_time, databricks_run_id=databricks_run_id, max_wait_time_sec=max_wait_time_sec, verbose_logs=verbose_logs, ): return time.sleep(poll_interval_sec)
class DatabricksJobRunner:
    """Submits jobs created using Dagster config to Databricks, and monitors their progress.

    Attributes:
        host (str): Databricks host, e.g. https://uksouth.azuredatabricks.net
        token (str): Databricks token
    """

    def __init__(
        self,
        host: str,
        token: str,
        poll_interval_sec: float = 5,
        max_wait_time_sec: int = DEFAULT_RUN_MAX_WAIT_TIME_SEC,
    ):
        self.host = check.str_param(host, "host")
        self.token = check.str_param(token, "token")
        self.poll_interval_sec = check.numeric_param(poll_interval_sec, "poll_interval_sec")
        self.max_wait_time_sec = check.int_param(max_wait_time_sec, "max_wait_time_sec")

        self._client: DatabricksClient = DatabricksClient(host=self.host, token=self.token)

    @property
    def client(self) -> DatabricksClient:
        """Return the underlying `DatabricksClient` object."""
        return self._client

    def submit_run(self, run_config: Mapping[str, Any], task: Mapping[str, Any]) -> int:
        """Submit a new run using the 'Runs submit' API."""
        existing_cluster_id = run_config["cluster"].get("existing")

        new_cluster = run_config["cluster"].get("new")

        # The Databricks API needs different keys to be present in API calls depending
        # on new/existing cluster, so we need to process the new_cluster config first.
        if new_cluster:
            new_cluster = new_cluster.copy()

            nodes = new_cluster.pop("nodes")
            if "instance_pool_id" in nodes:
                new_cluster["instance_pool_id"] = nodes["instance_pool_id"]
            else:
                node_types = nodes["node_types"]
                new_cluster["node_type_id"] = node_types["node_type_id"]
                if "driver_node_type_id" in node_types:
                    new_cluster["driver_node_type_id"] = node_types["driver_node_type_id"]

            cluster_size = new_cluster.pop("size")
            if "num_workers" in cluster_size:
                new_cluster["num_workers"] = cluster_size["num_workers"]
            else:
                new_cluster["autoscale"] = cluster_size["autoscale"]

            tags = new_cluster.get("custom_tags", [])
            tags.append({"key": "__dagster_version", "value": dagster.__version__})
            new_cluster["custom_tags"] = tags

        check.invariant(
            existing_cluster_id is not None or new_cluster is not None,
            "Invalid value for run_config.cluster",
        )

        # We'll always need some libraries, namely dagster/dagster_databricks/dagster_pyspark,
        # since they're imported by our scripts.
        # Add them if they're not already added by users in config.
        libraries = list(run_config.get("libraries", []))
        install_default_libraries = run_config.get("install_default_libraries", True)
        if install_default_libraries:
            python_libraries = {
                x["pypi"]["package"].split("==")[0].replace("_", "-")
                for x in libraries
                if "pypi" in x
            }

            for library_name, library in [
                ("dagster", dagster),
                ("dagster-databricks", dagster_databricks),
                ("dagster-pyspark", dagster_pyspark),
            ]:
                if library_name not in python_libraries:
                    libraries.append(
                        {"pypi": {"package": f"{library_name}=={library.__version__}"}}
                    )

        # Exactly one task type may be specified; enforce that here.
        check.invariant(
            sum(
                task.get(key) is not None
                for key in [
                    "notebook_task",
                    "spark_python_task",
                    "spark_jar_task",
                    "spark_submit_task",
                ]
            )
            == 1,
            "Multiple tasks specified in Databricks run",
        )

        config = {
            "run_name": run_config.get("run_name"),
            "new_cluster": new_cluster,
            "existing_cluster_id": existing_cluster_id,
            "libraries": libraries,
            **task,
        }
        return JobsService(self.client.api_client).submit_run(**config)["run_id"]

    def retrieve_logs_for_run_id(self, log: logging.Logger, databricks_run_id: int):
        """Retrieve the stdout and stderr logs for a run."""
        api_client = self.client.api_client

        run = JobsService(api_client).get_run(databricks_run_id)
        cluster = ClusterService(api_client).get_cluster(run["cluster_instance"]["cluster_id"])
        log_config = cluster.get("cluster_log_conf")
        if log_config is None:
            log.warning(
                "Logs not configured for cluster {cluster} used for run {run}".format(
                    cluster=cluster["cluster_id"], run=databricks_run_id
                )
            )
            return None
        if "s3" in log_config:
            log.warning("Retrieving S3 logs not yet implemented")
            return None
        elif "dbfs" in log_config:
            logs_prefix = log_config["dbfs"]["destination"]
            stdout = self.wait_for_dbfs_logs(log, logs_prefix, cluster["cluster_id"], "stdout")
            stderr = self.wait_for_dbfs_logs(log, logs_prefix, cluster["cluster_id"], "stderr")
            return stdout, stderr

    def wait_for_dbfs_logs(
        self,
        log: logging.Logger,
        prefix,
        cluster_id,
        filename,
        waiter_delay: int = 10,
        waiter_max_attempts: int = 10,
    ) -> Optional[str]:
        """Attempt up to `waiter_max_attempts` times to get logs from DBFS."""
        path = "/".join([prefix, cluster_id, "driver", filename])
        log.info(f"Retrieving logs from {path}")
        num_attempts = 0
        while num_attempts < waiter_max_attempts:
            try:
                logs = self.client.read_file(path)
                return logs.decode("utf-8")
            except requests.exceptions.HTTPError:
                num_attempts += 1
                time.sleep(waiter_delay)
        log.warning("Could not retrieve cluster logs!")
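
To illustrate the `run_config` shape that `submit_run` expects (derived from the key accesses above), a hedged end-to-end sketch; the host, token, Spark version, node type, and script path are placeholder assumptions:

runner = DatabricksJobRunner(host="https://example.cloud.databricks.com", token="<token>")

run_id = runner.submit_run(
    run_config={
        "run_name": "example-run",
        "cluster": {
            "new": {
                "spark_version": "12.2.x-scala2.12",
                "nodes": {"node_types": {"node_type_id": "i3.xlarge"}},
                "size": {"num_workers": 2},
            }
        },
    },
    task={
        "spark_python_task": {
            "python_file": "dbfs:/tmp/script.py",
            "parameters": [],
        }
    },
)

# Block until the run terminates, reusing the runner's polling settings.
runner.client.wait_for_run_to_complete(
    logger=logging.getLogger(__name__),
    databricks_run_id=run_id,
    poll_interval_sec=runner.poll_interval_sec,
    max_wait_time_sec=runner.max_wait_time_sec,
)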