"""Fetches a file from a gs bucket

Defines the `gsutil_file` repository rule, which copies one object out of a
Google Cloud Storage bucket with the `gsutil` CLI, optionally verifies its
SHA-256, and exposes it as the filegroup `@<name>//file`.
"""

# BUILD file written into the `file` package of the fetched repository. The
# placeholder receives a quoted Starlark string literal (built with repr() so
# quotes/backslashes in the file name cannot break the generated BUILD file)
# naming the downloaded file relative to that package.
_GSUTIL_FILE_BUILD = """\
package(default_visibility = ["//visibility:public"])

filegroup(
    name = "file",
    srcs = [{}],
)
"""

def _which_or_fail(ctx, tool):
    """Looks up an executable on PATH, failing with a clear message if absent.

    Args:
        ctx: the repository context
        tool: name of the executable to look up

    Returns:
        the path object returned by repository_ctx.which()
    """
    tool_path = ctx.which(tool)
    if tool_path == None:
        fail("'{}' was not found on PATH; it is required by gsutil_file".format(tool))
    return tool_path

def run_command(ctx, cmd, timeout = 1800):
    """Runs a cli command, failing the repository fetch if it exits non-zero.

    Args:
        ctx: the repository context
        cmd: the command to run as a list of strings (or path objects)
        timeout: maximum number of seconds the command may run (default 1800)

    Returns:
        returns the exec_result object returned by the repository_ctx.execute() method
    """
    result = ctx.execute(cmd, timeout = timeout)
    if result.return_code != 0:
        fail("Failed to run {} (exit code {}): {}".format(
            cmd,
            result.return_code,
            result.stderr,
        ))
    return result

def gsutil_download_file(ctx, path):
    """Downloads gs://<ctx.attr.bucket>/<ctx.attr.path> to 'path'.

    Args:
        ctx: the repository context
        path: path to download the file to
    """

    # Create the parent directory (and any missing ancestors) of the target;
    # `mkdir -p` succeeds if it already exists.
    mkdir_path = _which_or_fail(ctx, "mkdir")
    run_command(ctx, [mkdir_path, "-p", ctx.path(path).dirname])

    # Download the bucket file.
    gsutil_path = _which_or_fail(ctx, "gsutil")
    src_uri = "gs://{}/{}".format(ctx.attr.bucket, ctx.attr.path)
    ctx.report_progress("Downloading {}.".format(src_uri))
    run_command(ctx, [gsutil_path, "cp", src_uri, path])

def _sha256_command(ctx, path):
    """Builds a command that prints the SHA-256 of 'path' as its first field.

    Prefers GNU `sha256sum`; falls back to `shasum -a 256`, which prints
    the same "<hex>  <file>" layout.

    Args:
        ctx: the repository context
        path: the file to hash

    Returns:
        the command as a list suitable for run_command()
    """
    sha256sum_path = ctx.which("sha256sum")
    if sha256sum_path != None:
        return [sha256sum_path, path]
    shasum_path = ctx.which("shasum")
    if shasum_path != None:
        return [shasum_path, "-a", "256", path]
    fail("Neither 'sha256sum' nor 'shasum' was found on PATH; cannot verify {}".format(path))

def validate_checksum(ctx, path, expected_sha256):
    """Validates the checksum of the downloaded file.

    Verification is skipped when expected_sha256 is empty. The comparison is
    case-insensitive on the hex digest.

    Args:
        ctx: the repository context
        path: the downloaded file path
        expected_sha256: the expected sha256 checksum (hex), or "" to skip
    """
    if not expected_sha256:
        return

    ctx.report_progress("Checksumming {}.".format(path))
    sha256_result = run_command(ctx, _sha256_command(ctx, path))

    # First whitespace-separated field is the digest. GNU sha256sum prefixes
    # the line with "\" when the file name needs escaping; hex digits never
    # contain a backslash, so stripping it is safe.
    sha256 = sha256_result.stdout.strip().split(" ")[0].lstrip("\\").lower()
    expected = expected_sha256.lower()
    if sha256 != expected:
        fail("Checksum mismatch for {}, expected {}, got {}.".format(
            path,
            expected_sha256,
            sha256,
        ))

def validate_download_path(ctx, path):
    """Validates the download path.

    Fails unless 'path' lies strictly inside the repository directory and is
    not one of the files/directories the rule itself relies on.

    Args:
        ctx: the repository context
        path: the expected download file path
    """
    repo_root = str(ctx.path("."))

    # Compare normalized string forms rather than path objects.
    forbidden_files = [str(ctx.path(p)) for p in [
        ".",
        "WORKSPACE",
        "WORKSPACE.bazel",
        "REPO.bazel",
        "BUILD",
        "BUILD.bazel",
        "file",
        "file/BUILD",
        "file/BUILD.bazel",
    ]]
    path_str = str(path)

    # Require the separator after the root so that a sibling directory whose
    # name merely starts with the repository directory's name is rejected.
    inside_repo = path_str.startswith(repo_root + "/")
    if path_str in forbidden_files or not inside_repo:
        fail("'%s' cannot be used as downloaded_file_path in gsutil_file" % ctx.attr.downloaded_file_path)

def _gsutil_file_impl(ctx):
    """Implementation of the gsutil_file rule."""

    # Prepare download path (inside the `file` package of this repository).
    downloaded_file_path = ctx.attr.downloaded_file_path
    download_path = ctx.path("file/" + downloaded_file_path)

    # Validate download path
    validate_download_path(ctx, download_path)

    # Download
    gsutil_download_file(ctx, download_path)

    # Verify
    validate_checksum(ctx, download_path, ctx.attr.sha256)

    # Create filegroup; repr() yields a correctly escaped string literal.
    ctx.file("file/BUILD", _GSUTIL_FILE_BUILD.format(repr(downloaded_file_path)))

_gsutil_file_attrs = {
    "bucket": attr.string(
        mandatory = True,
        doc = "Google storage bucket name",
    ),
    "path": attr.string(
        mandatory = True,
        doc = "Relative path to the archive file within the bucket",
    ),
    "downloaded_file_path": attr.string(
        default = "downloaded",
        doc = "Path assigned to the downloaded file",
    ),
    # An empty value (the default) skips checksum verification.
    "sha256": attr.string(
        doc = "The expected SHA-256 of the downloaded file",
    ),
}

gsutil_file = repository_rule(
    implementation = _gsutil_file_impl,
    attrs = _gsutil_file_attrs,
    doc = """Downloads a file from a google storage bucket and makes it available
to be used as a file group.

Examples:
  To get my_deb.deb from gs://my_google_storage_bucket, add this to your
  WORKSPACE file:

  ```python
  load("//hack/build/rules/google_storage:gsutil_file.bzl", "gsutil_file")

  gsutil_file(
      name = "my_deb",
      bucket = "my_google_storage_bucket",
      path = "my_deb.deb",
  )
  ```

  Targets would specify `@my_deb//file` as a dependency to depend on this file.
""",
)