Source code for runhouse.resources.secrets.secret

import copy
import json
import os
from pathlib import Path

from typing import Dict, List, Optional, Union

import yaml

from runhouse.globals import configs, rns_client
from runhouse.logger import get_logger

from runhouse.resources.hardware import _get_cluster_from, Cluster
from runhouse.resources.resource import Resource
from runhouse.resources.secrets.utils import _delete_vault_secrets, load_config
from runhouse.rns.utils.api import load_resp_content, read_resp_data
from runhouse.utils import generate_default_name

logger = get_logger(__name__)


[docs]class Secret(Resource): RESOURCE_TYPE = "secret" USER_ENDPOINT = "user/secret" GROUP_ENDPOINT = "group/secret" DEFAULT_DIR = "~/.rh/secrets"
[docs] def __init__( self, name: Optional[str], values: Dict = None, dryrun: bool = False, **kwargs, ): """ Runhouse Secret object. .. note:: To create a Secret, please use one of the factory methods. """ super().__init__(name=name, dryrun=dryrun) self._values = values
@property def values(self): return self._values def config(self, condensed: bool = True, values: bool = True): config = super().config(condensed) if self._values and values: config.update( { "values": self._values, } ) return config @staticmethod def _write_shared_secret_to_local(config): import runhouse as rh new_creds_values = config["values"] folder_name = config["name"].replace("/", "_") path = f"{Secret.DEFAULT_DIR}/{folder_name}" private_key_value, public_key_value = new_creds_values.get( "private_key" ), new_creds_values.get("public_key") private_key_path = public_key_path = Path(path).expanduser() if private_key_value: if not private_key_path.exists(): os.makedirs(str(private_key_path)) private_file_path = private_key_path / "ssh-key" with open(str(private_file_path), "w") as f: f.write(private_key_value) private_file_path.chmod(0o600) if public_key_value: public_file_path = public_key_path / "ssh-key.pub" with open(str(public_key_path / "ssh-key.pub"), "w") as f: f.write(public_key_value) public_file_path.chmod(0o600) if private_key_value and public_key_value: new_creds_values = { "ssh_private_key": str(private_key_path / "ssh-key"), "ssh_public_key": str(public_key_path / "ssh-key.pub"), } if private_key_value and new_creds_values.get("ssh_user"): new_creds_values = { "ssh_private_key": str(private_key_path / "ssh-key"), "ssh_user": new_creds_values.get("ssh_user"), } return rh.secret( values=new_creds_values, name=f"loaded_secret_{config['name']}" )
[docs] @staticmethod def from_config(config: dict, dryrun: bool = False, _resolve_children: bool = True): if "provider" in config: from runhouse.resources.secrets.provider_secrets.providers import ( _get_provider_class, ) provider_class = _get_provider_class(config["provider"]) return provider_class.from_config(config, dryrun=dryrun) # checks if the config is a of a shared secret current_user = configs.username owner_user = config["owner"]["username"] if "owner" in config.keys() else None if owner_user and current_user != owner_user and config["values"]: return Secret._write_shared_secret_to_local(config) return Secret(**config, dryrun=dryrun)
[docs] @classmethod def from_name( cls, name, provider: str = None, load_from_den: bool = True, dryrun: bool = False, _alt_options: Dict = None, _resolve_children: bool = True, ): try: config = load_config(name, cls.USER_ENDPOINT) if config: return cls.from_config(config=config, dryrun=dryrun) except ValueError: pass provider = provider or name if provider in cls.builtin_providers(as_str=True): from runhouse.resources.secrets.provider_secrets.providers import ( _get_provider_class, ) provider_class = _get_provider_class(provider) return provider_class(name=name, provider=provider, dryrun=dryrun) raise ValueError(f"Could not locate secret {name}")
[docs] @classmethod def builtin_providers(cls, as_str: bool = False) -> List: """Return list of all Runhouse providers (as class objects) supported out of the box. Args: as_str (bool, optional): Whether to return the providers as a string or as a class. (Default: ``False``) """ from runhouse.resources.secrets.provider_secrets.providers import ( _str_to_provider_class, ) if as_str: return list(_str_to_provider_class.keys()) return list(_str_to_provider_class.values())
[docs] @classmethod def vault_secrets(cls, headers: Optional[Dict] = None) -> List[str]: """Get secret names that are stored in Vault""" uri = f"{rns_client.api_server_url}/{cls.USER_ENDPOINT}" resp = rns_client.session.get( uri, headers=headers or rns_client.request_headers(), ) if resp.status_code not in [200, 404]: raise Exception( f"Received [{resp.status_code}] from Den GET '{uri}': Failed to download secrets from Vault." ) response = read_resp_data(resp) return list(response.keys())
[docs] @classmethod def local_secrets(cls, names: List[str] = None) -> Dict[str, "Secret"]: """Get list of local secrets. Args: names (List[str], optional): Specific names of local secrets to retrieve. If ``None``, returns all locally detected secrets. (Default: ``None``) """ if not os.path.exists(os.path.expanduser("~/.rh/secrets")): return {} all_names = [ Path(file).stem for file in os.listdir(os.path.expanduser("~/.rh/secrets")) if file.endswith("json") ] names = [name for name in names if name in all_names] if names else all_names secrets = {} for name in names: path = os.path.expanduser(f"~/.rh/secrets/{name}.json") try: with open(path, "r") as f: config = json.load(f) if config["name"].startswith("~") or config["name"].startswith("^"): config["name"] = config["name"][2:] secrets[name] = Secret.from_config(config) except json.JSONDecodeError: # Ignore any empty / corrupted files continue return secrets
[docs] @classmethod def extract_provider_secrets(cls, names: List[str] = None) -> Dict[str, "Secret"]: """Extract secret values from providers. Returns a Dict mapping the provider name to Secret. Args: names (List[str]): List of provider names to extract secrets for. If ``None``, returns secrets for all detected providers. (Default: ``None``) """ from runhouse.resources.secrets.provider_secrets.providers import ( _str_to_provider_class, ) from runhouse.resources.secrets.secret_factory import provider_secret secrets = {} names = names or _str_to_provider_class.keys() for provider in names: if provider in ["ssh", "sky"]: continue try: secret = provider_secret(provider=provider) if provider == "sky": provider = f"ssh-{secret.key}" secrets[provider] = secret except ValueError: continue # locally configured ssh secrets if "ssh" in names: default_ssh_folder = "~/.ssh" ssh_files = os.listdir(os.path.expanduser(default_ssh_folder)) for file in ssh_files: if file != "sky-key" and f"{file}.pub" in ssh_files: name = f"ssh-{file}" secret = provider_secret( provider="ssh", name=name, path=os.path.join(default_ssh_folder, file), ) secrets[name] = secret return secrets
# TODO: refactor this code to reuse rns_client save_config code instead of rewriting
[docs] def save( self, name: str = None, save_values: bool = True, headers: Optional[Dict] = None, folder: str = None, ): """ Save the secret config to Den. Save the secret values into Vault if the user is logged in, or to local if not or if the resource is a local resource. Args: name (str, optional): Name to save the secret resource as. save_values (str, optional): Whether to save the values of the secret to Vault in addition to saving the metadata to Den. (Default: ``True``) headers (Dict, optional): Request headers to provide for the request to Den. Contains the user's auth token. Example: ``{"Authorization": f"Bearer {token}"}`` (Default: ``None``) folder (str, optional): If specified, save the secret to that folder in Den (e.g. saving secrets for a cluster associated with an organization). (Default: ``None``) """ if name: self.name = name elif not self.name: raise ValueError("A resource must have a name to be saved.") self._rns_folder = folder or self._rns_folder or rns_client.current_folder config = self.config() config["name"] = self.rns_address if "values" in config: # don't save values into Den config del config["values"] headers = headers or rns_client.request_headers() # Save metadata to Den if self.rns_address.startswith("/"): logger.info(f"Saving config for {self.rns_address} to Den") payload = rns_client.resource_request_payload(config) uri = f"{rns_client.api_server_url}/resource" resp = rns_client.session.post( uri, data=json.dumps(payload), headers=headers, ) # If resource config hasn't changed (i.e. nothing to update) will return a 422 if resp.status_code not in [200, 422]: raise Exception( f"Received [{resp.status_code}] from Den POST '{uri}': Failed to save metadata to Den: {load_resp_content(resp)}" ) if save_values and self.values: logger.info(f"Saving secrets for {self.rns_address} to Vault") resource_uri = rns_client.resource_uri(self.rns_address) uri = f"{rns_client.api_server_url}/{self.USER_ENDPOINT}/{resource_uri}" resp = rns_client.session.put( uri, data=json.dumps( {"name": self.rns_address, "data": {"values": self.values}} ), headers=headers, ) if resp.status_code != 200: raise Exception( f"Received [{resp.status_code}] from Den PUT '{uri}': Failed to put resources in Vault: {load_resp_content(resp)}" ) else: config_path = os.path.expanduser(f"~/.rh/secrets/{self.name}.json") os.makedirs(os.path.dirname(config_path), exist_ok=True) if save_values: config["values"] = self.values with open(config_path, "w") as f: json.dump(config, f, indent=4) logger.info(f"Saving config for {self.rns_address} to: {config_path}") return self
[docs] def delete(self, headers: Optional[Dict] = None): """Delete the secret config from Den and from Vault/local.""" if not (self.in_vault() or self.in_local()): logger.warning( "Can not delete a secret that has not been saved down to Vault or local." ) else: if self.rns_address and self.rns_address.startswith("/"): self._delete_secret_configs(headers) else: self._delete_local_config() configs.delete_provider(self.name)
def _delete_local_config(self): config_path = os.path.expanduser(f"~/.rh/secrets/{self.name}.json") if os.path.exists(config_path): os.remove(config_path) def _delete_secret_configs(self, headers: Optional[Dict] = None): headers = headers or rns_client.request_headers() # Delete secrets in Vault resource_uri = rns_client.resource_uri(self.rns_address) _delete_vault_secrets(resource_uri, self.USER_ENDPOINT, headers=headers) # Delete Den data for resource uri = f"{rns_client.api_server_url}/resource/{resource_uri}" resp = rns_client.session.delete( uri, headers=headers, ) if resp.status_code != 200: logger.error( f"Received [{resp.status_code}] from Den DELETE '{uri}': Failed to delete secret resource from Den: {load_resp_content(resp)}" )
[docs] def to( self, system: Union[str, Cluster], name: Optional[str] = None, process: Optional[str] = None, ): """Return a copy of the secret on a system. Args: system (str or Cluster): Cluster to send the secret to name (str, optional): Name to assign the resource on the cluster. process (str, optional): Process on the cluster to send the secret to. Example: >>> secret.to(my_cluster, path=secret.path) """ new_secret = copy.deepcopy(self) new_secret.name = name or self.name or generate_default_name(prefix="secret") system = _get_cluster_from(system) if system.on_this_cluster(): new_secret.pin() else: system.put_resource(new_secret, process=process) return new_secret
def _write_to_file(self, path: str, overwrite: bool = False, format: str = "json"): full_path = os.path.expanduser(path) if os.path.exists(full_path) and not overwrite: logger.info( f"{path} already exists and overwrite set to `False`. Not overriding contents." ) elif format == "json": with open(full_path, "w+") as f: json.dump(self.values, f, indent=4) elif format == "yaml": with open(full_path, "w+") as f: yaml.safe_dump(self.values, f) else: logger.error( f"Only 'json' and 'yaml' formats currently supported, not {format}." )
[docs] def in_local(self): """Whether the secret config is stored locally (as opposed to Vault).""" path = os.path.expanduser(f"~/.rh/secrets/{self.name}.json") if os.path.exists(os.path.expanduser(path)): return True return False
[docs] def in_vault(self, headers=None): """Whether the secret is stored in Vault""" if not self.rns_address: return False resource_uri = rns_client.resource_uri(self.rns_address) resp = rns_client.session.get( f"{rns_client.api_server_url}/{self.USER_ENDPOINT}/{resource_uri}", headers=headers or rns_client.request_headers(), ) if resp.status_code != 200: return False response = read_resp_data(resp) # TODO: switch this to use self.name once vault updates if response and response[list(response.keys())[0]]: return True return False
def is_present(self): if self.values: return True return False