You are viewing an outdated version of the documentation.

This documentation is for an older version (1.4.7) of Dagster. You can view the version of this page from our latest release below.

Source code for dagster_gcp.gcs.file_manager

import io
import uuid
from contextlib import contextmanager
from typing import Optional

import dagster._check as check
from dagster._core.storage.file_manager import (
    FileHandle,
    FileManager,
    TempfileManager,
    check_file_like_obj,
)
from google.cloud import storage


class GCSFileHandle(FileHandle):
    """Handle identifying a single file stored in GCS by bucket name and key."""

    def __init__(self, gcs_bucket: str, gcs_key: str):
        # Both parts are passed through ``check.str_param`` before being stored.
        self._gcs_bucket = check.str_param(gcs_bucket, "gcs_bucket")
        self._gcs_key = check.str_param(gcs_key, "gcs_key")

    @property
    def gcs_bucket(self) -> str:
        """str: Name of the GCS bucket that holds the file."""
        return self._gcs_bucket

    @property
    def gcs_key(self) -> str:
        """str: Key of the file inside the bucket."""
        return self._gcs_key

    @property
    def gcs_path(self) -> str:
        """str: The file's GCS URL, in the form ``gs://<bucket>/<key>``."""
        return f"gs://{self.gcs_bucket}/{self.gcs_key}"

    @property
    def path_desc(self) -> str:
        """str: The file's GCS URL (same value as :attr:`gcs_path`)."""
        return self.gcs_path
class GCSFileManager(FileManager):
    """File manager that writes files to, and reads files from, a GCS bucket.

    Written files are uploaded to ``gcs_bucket`` under ``<gcs_base_key>/<key>[.<ext>]``.
    Files being read are first downloaded into a local temp file; the local path is
    cached per GCS path so a given handle is downloaded at most once until
    :meth:`delete_local_temp` is called.

    Args:
        client: A ``google.cloud.storage`` client used for all bucket operations.
        gcs_bucket: Name of the bucket that written files are uploaded to.
        gcs_base_key: Key prefix prepended (with a ``/``) to every written file's key.
    """

    def __init__(self, client, gcs_bucket, gcs_base_key):
        self._client = check.inst_param(client, "client", storage.client.Client)
        self._gcs_bucket = check.str_param(gcs_bucket, "gcs_bucket")
        self._gcs_base_key = check.str_param(gcs_base_key, "gcs_base_key")
        # Maps a handle's ``gs://`` path to the name of the local temp file holding it.
        self._local_handle_cache = {}
        self._temp_file_manager = TempfileManager()

    def copy_handle_to_local_temp(self, file_handle):
        """Download the file behind ``file_handle`` if needed and return its local path."""
        self._download_if_not_cached(file_handle)
        return self._get_local_path(file_handle)

    def _download_if_not_cached(self, file_handle):
        """Ensure a local copy of ``file_handle`` exists, then return the handle unchanged."""
        if not self._file_handle_cached(file_handle):
            # instigate download
            temp_file_obj = self._temp_file_manager.tempfile()
            temp_name = temp_file_obj.name
            bucket_obj = self._client.bucket(file_handle.gcs_bucket)
            bucket_obj.blob(file_handle.gcs_key).download_to_file(temp_file_obj)
            # The data was written through ``temp_file_obj`` but readers reopen the file
            # by name. Flush so buffered bytes are on disk before the path is published;
            # otherwise a small download can be read back empty or truncated.
            temp_file_obj.flush()
            self._local_handle_cache[file_handle.gcs_path] = temp_name

        return file_handle

    @contextmanager
    def read(self, file_handle, mode="rb"):
        """Context manager yielding an open local copy of the file behind ``file_handle``.

        Args:
            file_handle: The :class:`GCSFileHandle` to read.
            mode: ``"rb"`` (default) for bytes or ``"r"`` for text decoded as UTF-8.
                Any other mode fails the parameter invariant check.

        Yields:
            The opened local file object; it is closed when the context exits.
        """
        check.inst_param(file_handle, "file_handle", GCSFileHandle)
        check.str_param(mode, "mode")
        check.param_invariant(mode in {"r", "rb"}, "mode")

        self._download_if_not_cached(file_handle)

        # Binary mode must not be given an encoding; text mode is always UTF-8.
        encoding = None if mode == "rb" else "utf-8"
        with open(self._get_local_path(file_handle), mode, encoding=encoding) as file_obj:
            yield file_obj

    def _file_handle_cached(self, file_handle):
        """Return whether a local copy of ``file_handle`` has already been recorded."""
        return file_handle.gcs_path in self._local_handle_cache

    def _get_local_path(self, file_handle):
        """Return the cached local path for ``file_handle`` (``KeyError`` if not cached)."""
        return self._local_handle_cache[file_handle.gcs_path]

    def read_data(self, file_handle):
        """Return the full contents of the file behind ``file_handle`` as bytes."""
        with self.read(file_handle, mode="rb") as file_obj:
            return file_obj.read()

    def write_data(self, data, ext=None, key: Optional[str] = None):
        """Upload ``data`` (bytes) and return a :class:`GCSFileHandle` for it.

        Args:
            data: The bytes to upload.
            ext: Optional extension appended to the key as ``.<ext>``.
            key: Optional file key; a random UUID4 string is used when omitted.
        """
        key = check.opt_str_param(key, "key", default=str(uuid.uuid4()))
        check.inst_param(data, "data", bytes)
        return self.write(io.BytesIO(data), mode="wb", key=key, ext=ext)

    def write(self, file_obj, mode="wb", ext=None, key: Optional[str] = None):
        """Upload the contents of ``file_obj`` and return a :class:`GCSFileHandle` for it.

        Args:
            file_obj: A file-like object to upload (validated by ``check_file_like_obj``).
            mode: Accepted as part of the signature; not used by this implementation.
            ext: Optional extension appended to the key as ``.<ext>``.
            key: Optional file key; a random UUID4 string is used when omitted.
        """
        key = check.opt_str_param(key, "key", default=str(uuid.uuid4()))
        check_file_like_obj(file_obj)
        gcs_key = self.get_full_key(key + (("." + ext) if ext is not None else ""))
        bucket_obj = self._client.bucket(self._gcs_bucket)
        bucket_obj.blob(gcs_key).upload_from_file(file_obj)
        return GCSFileHandle(self._gcs_bucket, gcs_key)

    def get_full_key(self, file_key):
        """Return ``file_key`` prefixed with this manager's base key and a ``/``."""
        return f"{self._gcs_base_key}/{file_key}"

    def delete_local_temp(self):
        """Close the temp-file manager and forget all cached local copies."""
        self._temp_file_manager.close()
        # NOTE(review): presumably close() removes the temp files -- confirm against
        # TempfileManager. Dropping the cache keeps later reads from resolving to
        # paths that were just cleaned up; they re-download instead.
        self._local_handle_cache.clear()