import copy
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Union
from runhouse.globals import configs, rns_client
from runhouse.logger import get_logger
from runhouse.resources.hardware import _get_cluster_from, Cluster
from runhouse.resources.resource import Resource
from runhouse.resources.secrets.utils import _delete_vault_secrets, load_config
from runhouse.rns.utils.api import load_resp_content, read_resp_data
from runhouse.utils import generate_default_name
# Module-level logger; the Secret methods below use it for save/delete status and error messages.
logger = get_logger(__name__)
class Secret(Resource):
    """Runhouse Secret resource.

    Wraps a dictionary of secret ``values``. Secrets whose RNS address starts with ``/`` are
    saved to Den (metadata) and Vault (values); other secrets are saved as a JSON config file
    under ``DEFAULT_DIR``.
    """

    RESOURCE_TYPE = "secret"
    USER_ENDPOINT = "user/secret"
    GROUP_ENDPOINT = "group/secret"

    # Local directory holding ``<name>.json`` secret configs and shared-secret key folders.
    DEFAULT_DIR = "~/.rh/secrets"

    def __init__(
        self,
        name: Optional[str],
        values: Dict = None,
        dryrun: bool = False,
        **kwargs,
    ):
        """
        Runhouse Secret object.

        .. note::
            To create a Secret, please use one of the factory methods.
        """
        # ``from_config`` forwards the whole config dict as keyword arguments; keys other than
        # ``name`` / ``values`` / ``dryrun`` land in ``kwargs`` and are ignored here.
        super().__init__(name=name, dryrun=dryrun)
        self._values = values

    @property
    def values(self):
        """The secret's values dictionary (``None`` if none were provided)."""
        return self._values

    @staticmethod
    def _local_config_path(name: str) -> str:
        """Return the expanded path of the local JSON config file for the secret ``name``."""
        return os.path.expanduser(f"{Secret.DEFAULT_DIR}/{name}.json")

    def config(self, condensed: bool = True, values: bool = True):
        """Return the resource config, optionally including the secret values.

        Args:
            condensed (bool, optional): Passed through to the base ``Resource.config``. (Default: ``True``)
            values (bool, optional): Whether to include the ``values`` entry. It is only added when the
                secret actually has (non-empty) values. (Default: ``True``)
        """
        config = super().config(condensed)
        if self._values and values:
            config["values"] = self._values
        return config

    @staticmethod
    def _write_shared_secret_to_local(config):
        """Write the key material of a secret shared by another user to local files.

        ``private_key`` / ``public_key`` values are written to
        ``<DEFAULT_DIR>/<name with "/" replaced by "_">/ssh-key`` and ``ssh-key.pub``
        (mode ``0o600``). When a private key is present together with a public key and/or an
        ``ssh_user``, the returned secret references the written files by path instead of
        embedding the raw key; otherwise it keeps the original values.

        Returns:
            Secret: A new secret named ``loaded_secret_<config name>``.
        """
        import runhouse as rh

        orig_values = config["values"]
        folder_name = config["name"].replace("/", "_")
        key_dir = Path(f"{Secret.DEFAULT_DIR}/{folder_name}").expanduser()
        private_file_path = key_dir / "ssh-key"
        public_file_path = key_dir / "ssh-key.pub"

        private_key_value = orig_values.get("private_key")
        public_key_value = orig_values.get("public_key")
        ssh_user = orig_values.get("ssh_user")

        if private_key_value or public_key_value:
            # Create the folder for either key; creating it only for the private key meant a
            # public key on its own could fail to be written.
            key_dir.mkdir(parents=True, exist_ok=True)

        if private_key_value:
            with open(str(private_file_path), "w") as f:
                f.write(private_key_value)
            private_file_path.chmod(0o600)

        if public_key_value:
            with open(str(public_file_path), "w") as f:
                f.write(public_key_value)
            public_file_path.chmod(0o600)

        new_creds_values = orig_values
        if private_key_value and (public_key_value or ssh_user):
            # ``ssh_user`` is read from the ORIGINAL values, so it is kept even when both keys
            # are present (previously it was looked up on an already-rewritten dict and lost).
            new_creds_values = {"ssh_private_key": str(private_file_path)}
            if public_key_value:
                new_creds_values["ssh_public_key"] = str(public_file_path)
            if ssh_user:
                new_creds_values["ssh_user"] = ssh_user

        return rh.secret(
            values=new_creds_values, name=f"loaded_secret_{config['name']}"
        )

    @staticmethod
    def from_config(config: dict, dryrun: bool = False, _resolve_children: bool = True):
        """Construct a secret from a config dict.

        Configs with a ``provider`` key are delegated to the matching provider secret class.
        Configs owned by a different user that carry values are treated as shared secrets and
        written to local files via ``_write_shared_secret_to_local``.
        """
        if "provider" in config:
            from runhouse.resources.secrets.provider_secrets.providers import (
                _get_provider_class,
            )

            provider_class = _get_provider_class(config["provider"])
            return provider_class.from_config(config, dryrun=dryrun)

        # Check whether the config is of a secret shared by another user.
        current_user = configs.username
        owner_user = config["owner"]["username"] if "owner" in config else None
        # ``.get``: a shared config without a ``values`` key must not raise KeyError.
        if owner_user and current_user != owner_user and config.get("values"):
            return Secret._write_shared_secret_to_local(config)

        return Secret(**config, dryrun=dryrun)

    @classmethod
    def from_name(
        cls,
        name,
        load_from_den: bool = True,
        dryrun: bool = False,
        _alt_options: Dict = None,
        _resolve_children: bool = True,
    ):
        """Load a secret by name.

        The config is first looked up with ``load_config``. If that yields nothing (or raises
        ``ValueError``) and ``name`` is a builtin provider, a provider secret for it is returned.
        ``load_from_den``, ``_alt_options`` and ``_resolve_children`` are accepted for interface
        compatibility but are not used here.

        Raises:
            ValueError: If the secret cannot be located.
        """
        try:
            config = load_config(name, cls.USER_ENDPOINT)
            if config:
                return cls.from_config(config=config, dryrun=dryrun)
        except ValueError:
            pass

        if name in cls.builtin_providers(as_str=True):
            from runhouse.resources.secrets.provider_secrets.providers import (
                _get_provider_class,
            )

            provider_class = _get_provider_class(name)
            return provider_class(provider=name, dryrun=dryrun)

        raise ValueError(f"Could not locate secret {name}")

    @classmethod
    def builtin_providers(cls, as_str: bool = False) -> List:
        """Return list of all Runhouse providers (as class objects) supported out of the box.

        Args:
            as_str (bool, optional): Whether to return the providers as a string or as a class.
                (Default: ``False``)
        """
        from runhouse.resources.secrets.provider_secrets.providers import (
            _str_to_provider_class,
        )

        if as_str:
            return list(_str_to_provider_class.keys())
        return list(_str_to_provider_class.values())

    @classmethod
    def vault_secrets(cls, headers: Optional[Dict] = None) -> List[str]:
        """Get secret names that are stored in Vault.

        A 404 from Den is not treated as an error; if the response carries no data an empty
        list is returned.

        Raises:
            Exception: If Den responds with a status other than 200 or 404.
        """
        uri = f"{rns_client.api_server_url}/{cls.USER_ENDPOINT}"
        resp = rns_client.session.get(
            uri,
            headers=headers or rns_client.request_headers(),
        )
        if resp.status_code not in [200, 404]:
            raise Exception(
                f"Received [{resp.status_code}] from Den GET '{uri}': Failed to download secrets from Vault."
            )
        response = read_resp_data(resp)
        # Guard against an empty / missing payload (e.g. on 404) instead of calling .keys() on it.
        return list(response.keys()) if response else []

    @classmethod
    def local_secrets(cls, names: List[str] = None) -> Dict[str, "Secret"]:
        """Get the secrets saved locally, keyed by name.

        Args:
            names (List[str], optional): Specific names of local secrets to retrieve. If ``None``, returns all
                locally detected secrets. (Default: ``None``)

        Returns:
            Dict[str, Secret]: Mapping of secret name to the secret loaded from its local config.
            Empty / corrupted config files are skipped.
        """
        secrets_dir = os.path.expanduser(Secret.DEFAULT_DIR)
        if not os.path.exists(secrets_dir):
            return {}

        # Require the ".json" extension itself: a bare ``endswith("json")`` also matches other
        # entries (e.g. shared-secret key folders whose names end in "json"), whose derived
        # "<stem>.json" file doesn't exist and would make ``open`` below raise.
        all_names = [
            Path(file).stem for file in os.listdir(secrets_dir) if file.endswith(".json")
        ]
        if names:
            available = set(all_names)
            names = [name for name in names if name in available]
        else:
            names = all_names

        secrets = {}
        for name in names:
            try:
                with open(Secret._local_config_path(name), "r") as f:
                    config = json.load(f)
                if config["name"].startswith(("~", "^")):
                    # Drop the 2-character address prefix (presumably "~/" or "^/").
                    config["name"] = config["name"][2:]
                secrets[name] = Secret.from_config(config)
            except json.JSONDecodeError:
                # Ignore any empty / corrupted files
                continue
        return secrets

    # TODO: refactor this code to reuse rns_client save_config code instead of rewriting
    def save(
        self,
        name: str = None,
        save_values: bool = True,
        headers: Optional[Dict] = None,
        folder: str = None,
    ):
        """
        Save the secret config to Den. Save the secret values into Vault if the user is logged in,
        or to local if not or if the resource is a local resource.

        Args:
            name (str, optional): Name to save the secret resource as.
            save_values (bool, optional): Whether to save the values of the secret to Vault in addition
                to saving the metadata to Den. For local secrets, whether to write the values into the
                local config file. (Default: ``True``)
            headers (Dict, optional): Request headers to provide for the request to RNS. Contains the
                user's auth token. Example: ``{"Authorization": f"Bearer {token}"}`` (Default: ``None``)
            folder (str, optional): If specified, save the secret to that folder in Den (e.g. saving secrets
                for a cluster associated with an organization). (Default: ``None``)

        Returns:
            Secret: ``self``.

        Raises:
            ValueError: If neither ``name`` nor ``self.name`` is set.
            Exception: If Den rejects the metadata or Vault rejects the values.
        """
        if name:
            self.name = name
        elif not self.name:
            raise ValueError("A resource must have a name to be saved.")

        self._rns_folder = folder or self._rns_folder or rns_client.current_folder

        config = self.config()
        config["name"] = self.rns_address
        # Don't save values into the Den config; they go to Vault (or the local file) instead.
        config.pop("values", None)

        if self.rns_address.startswith("/"):
            # Auth headers are only needed for the Den / Vault requests.
            headers = headers or rns_client.request_headers()

            # Save metadata to Den
            logger.info(f"Saving config for {self.rns_address} to Den")
            payload = rns_client.resource_request_payload(config)
            uri = f"{rns_client.api_server_url}/resource"
            resp = rns_client.session.post(
                uri,
                data=json.dumps(payload),
                headers=headers,
            )
            # If resource config hasn't changed (i.e. nothing to update) will return a 422
            if resp.status_code not in [200, 422]:
                raise Exception(
                    f"Received [{resp.status_code}] from Den POST '{uri}': Failed to save metadata to Den: {load_resp_content(resp)}"
                )

            if save_values and self.values:
                logger.info(f"Saving secrets for {self.rns_address} to Vault")
                resource_uri = rns_client.resource_uri(self.rns_address)
                uri = f"{rns_client.api_server_url}/{self.USER_ENDPOINT}/{resource_uri}"
                resp = rns_client.session.put(
                    uri,
                    data=json.dumps(
                        {"name": self.rns_address, "data": {"values": self.values}}
                    ),
                    headers=headers,
                )
                if resp.status_code != 200:
                    raise Exception(
                        f"Received [{resp.status_code}] from Den PUT '{uri}': Failed to put resources in Vault: {load_resp_content(resp)}"
                    )
        else:
            config_path = self._local_config_path(self.name)
            os.makedirs(os.path.dirname(config_path), exist_ok=True)
            if save_values:
                config["values"] = self.values
            with open(config_path, "w") as f:
                json.dump(config, f, indent=4)
            logger.info(f"Saving config for {self.rns_address} to: {config_path}")

        return self

    def delete(self, headers: Optional[Dict] = None):
        """Delete the secret config from Den and from Vault/local.

        Secrets with an RNS address starting with ``/`` are deleted from Vault and Den; others
        have their local config file removed. ``configs.delete_provider`` is then called with the
        secret's name. If the secret is found neither in Vault nor locally, only a warning is
        logged.
        """
        if not (self.in_vault() or self.in_local()):
            logger.warning(
                "Can not delete a secret that has not been saved down to Vault or local."
            )
            return

        if self.rns_address and self.rns_address.startswith("/"):
            self._delete_secret_configs(headers)
        else:
            self._delete_local_config()
        configs.delete_provider(self.name)

    def _delete_local_config(self):
        """Remove the local JSON config file for this secret, if it exists."""
        config_path = self._local_config_path(self.name)
        if os.path.exists(config_path):
            os.remove(config_path)

    def _delete_secret_configs(self, headers: Optional[Dict] = None):
        """Delete this secret's values from Vault and its resource entry from Den.

        A failed Den delete is logged as an error rather than raised.
        """
        headers = headers or rns_client.request_headers()

        # Delete secrets in Vault
        resource_uri = rns_client.resource_uri(self.rns_address)
        _delete_vault_secrets(resource_uri, self.USER_ENDPOINT, headers=headers)

        # Delete RNS data for resource
        uri = f"{rns_client.api_server_url}/resource/{resource_uri}"
        resp = rns_client.session.delete(
            uri,
            headers=headers,
        )
        if resp.status_code != 200:
            logger.error(
                f"Received [{resp.status_code}] from Den DELETE '{uri}': Failed to delete secret resource from Den: {load_resp_content(resp)}"
            )

    def to(
        self,
        system: Union[str, Cluster],
        name: Optional[str] = None,
        env: Optional["Env"] = None,
    ):
        """Return a copy of the secret on a system.

        Args:
            system (str or Cluster): Cluster to send the secret to
            name (str, optional): Name to assign the resource on the cluster. Defaults to the
                secret's own name, or a generated ``secret``-prefixed name if it has none.
            env (Env, optional): Env to send the secret to. Defaults to the cluster's default env.

        Returns:
            Secret: The copied secret (pinned locally if ``system`` is the current cluster,
            otherwise put on ``system``).

        Example:
            >>> remote_secret = secret.to(my_cluster, name="my_secret")
        """
        new_secret = copy.deepcopy(self)
        new_secret.name = name or self.name or generate_default_name(prefix="secret")

        system = _get_cluster_from(system)
        if system.on_this_cluster():
            new_secret.pin()
        else:
            env = env or system.default_env
            system.put_resource(new_secret, env=env)

        return new_secret

    def in_local(self):
        """Whether the secret config is stored locally (as opposed to Vault)."""
        return os.path.exists(self._local_config_path(self.name))

    def in_vault(self, headers=None):
        """Whether the secret is stored in Vault.

        Returns ``False`` when the secret has no RNS address or Den doesn't answer with 200.
        """
        if not self.rns_address:
            return False

        resource_uri = rns_client.resource_uri(self.rns_address)
        resp = rns_client.session.get(
            f"{rns_client.api_server_url}/{self.USER_ENDPOINT}/{resource_uri}",
            headers=headers or rns_client.request_headers(),
        )
        if resp.status_code != 200:
            return False

        response = read_resp_data(resp)
        # TODO: switch this to use self.name once vault updates
        # Only the first entry of the response is inspected.
        return bool(response) and bool(next(iter(response.values())))

    def is_present(self):
        """Whether the secret holds any (non-empty) values."""
        return bool(self.values)