import copy
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union
from runhouse.globals import rns_client
from runhouse.logger import get_logger
from runhouse.resources.hardware.cluster import Cluster
from runhouse.resources.secrets.provider_secrets.provider_secret import ProviderSecret
# Module-level logger, obtained from runhouse's ``get_logger`` with this module's name.
logger = get_logger(__name__)
class SSHSecret(ProviderSecret):
    """Provider secret backed by an SSH key pair.

    The pair is stored on disk as a private key file plus a sibling file with
    the same name and a ``.pub`` suffix holding the public key.

    .. note::
            To create a SSHSecret, please use the factory method :func:`provider_secret` with ``provider="ssh"``.
    """

    _DEFAULT_CREDENTIALS_PATH = "~/.ssh"
    _PROVIDER = "ssh"
    _DEFAULT_KEY = "id_rsa"

    def __init__(
        self,
        name: Optional[str] = None,
        provider: Optional[str] = None,
        values: Optional[Dict] = None,
        path: Optional[str] = None,
        key: Optional[str] = None,
        dryrun: bool = True,
        **kwargs,
    ):
        """Initialize the secret and work out which key file it refers to.

        Args:
            name: Name of the secret; forwarded to ``ProviderSecret``.
            provider: Provider name; forwarded to ``ProviderSecret``.
            values: Secret values; forwarded to ``ProviderSecret``. When
                omitted, a fresh empty dict is passed.
            path: Path of the private key file; forwarded to ``ProviderSecret``.
            key: File name of the private key. When omitted it falls back to
                the basename of ``path``, then to ``name``, then to
                ``_DEFAULT_KEY``.
            dryrun: Forwarded to ``ProviderSecret``.
            **kwargs: Accepted so that extra config entries do not raise;
                they are not forwarded to the parent class.
        """
        # Resolve the key name in explicit priority order. An explicitly
        # passed ``key`` must win even when no ``path`` is given.
        if key:
            self.key = key
        elif path:
            self.key = os.path.basename(path)
        else:
            self.key = name or self._DEFAULT_KEY

        super().__init__(
            name=name,
            provider=provider,
            # Never share one dict object between instances.
            values={} if values is None else values,
            path=path,
            dryrun=dryrun,
        )

        # A bare default directory is not a key file: point at the key inside it.
        if self.path == self._DEFAULT_CREDENTIALS_PATH:
            self.path = str(Path(self._DEFAULT_CREDENTIALS_PATH) / self.key)

    @staticmethod
    def from_config(config: dict, dryrun: bool = False, _resolve_children: bool = True):
        """Build an ``SSHSecret`` from a config dict.

        Args:
            config: Keyword arguments for the ``SSHSecret`` constructor.
            dryrun: Passed to the constructor as ``dryrun``.
            _resolve_children: Accepted but not used by this implementation.

        Returns:
            The constructed ``SSHSecret``.
        """
        return SSHSecret(**config, dryrun=dryrun)

    def save(
        self,
        name: Optional[str] = None,
        save_values: bool = True,
        headers: Optional[Dict] = None,
        folder: Optional[str] = None,
    ):
        """Save the secret via ``ProviderSecret.save``.

        Before delegating, the secret is given a name (``name`` if provided,
        otherwise ``ssh-<key>`` when it has none yet) and a path located under
        the current user's home directory is rewritten to start with ``~``.

        Args:
            name: Name to assign to the secret before saving.
            save_values: Forwarded to the parent ``save``.
            headers: Request headers forwarded to the parent ``save``;
                defaults to ``rns_client.request_headers()``.
            folder: Forwarded to the parent ``save``.

        Returns:
            Whatever the parent ``save`` returns.
        """
        if name:
            self.name = name
        elif not self.name:
            self.name = f"ssh-{self.key}"

        if self.path:
            try:
                rel_path = "~" / Path(self.path).relative_to(Path.home())
                self.path = str(rel_path)
            except (ValueError, RuntimeError):
                # ValueError: the path is not under the home directory.
                # RuntimeError: the home directory could not be determined.
                # In both cases the path is kept unchanged.
                pass

        return super().save(
            save_values=save_values,
            headers=headers or rns_client.request_headers(),
            folder=folder,
        )

    def _write_to_file(
        self,
        path: str,
        values: Optional[Dict] = None,
        overwrite: bool = False,
        write_config: bool = True,
    ):
        """Write the key pair to ``path`` (private) and ``path + ".pub"`` (public).

        Existing key files are never overwritten.

        Args:
            path: Destination of the private key; ``~`` is expanded.
            values: Dict read for ``"private_key"`` and ``"public_key"``;
                defaults to ``self.values``.
            overwrite: Accepted but not consulted by this implementation.
            write_config: When true, ``_add_to_rh_config`` is called on the
                returned secret with ``path``.

        Returns:
            ``self`` (with ``self.path`` set to ``path``) when both key files
            already exist; otherwise a deep copy of this secret with its
            values cleared, ``path`` set, and name ``ssh-<basename of path>``.
        """
        priv_key_path = Path(os.path.expanduser(path))
        pub_key_path = Path(f"{priv_key_path}.pub")
        values = values or self.values

        if priv_key_path.exists() and pub_key_path.exists():
            if values == self._from_path(path=path):
                logger.info(f"Secrets already exist in {path}. Skipping.")
            else:
                logger.warning(
                    f"SSH Secrets for {self.name or self.key} already exist in {path}. "
                    "Automatically overriding SSH keys is not supported by Runhouse. "
                    "Please manually edit these files."
                )
            self.path = path
            return self

        priv_key_path.parent.mkdir(parents=True, exist_ok=True)

        private_key = values.get("private_key")
        if private_key is not None and not priv_key_path.exists():
            # Create the file owner-only *before* writing, so the key material
            # is never readable by others under a permissive umask.
            priv_key_path.touch(mode=0o600)
            priv_key_path.write_text(private_key)
            priv_key_path.chmod(0o600)

        public_key = values.get("public_key")
        if public_key is not None and not pub_key_path.exists():
            pub_key_path.write_text(public_key)
            pub_key_path.chmod(0o600)

        new_secret = copy.deepcopy(self)
        new_secret._values = None
        new_secret.path = path
        new_secret.name = f"ssh-{os.path.basename(path)}"

        if write_config:
            try:
                new_secret._add_to_rh_config(val=path)
            except TypeError:
                # NOTE(review): best-effort; the condition under which
                # _add_to_rh_config raises TypeError is not visible here.
                pass

        return new_secret

    def _from_path(self, path: str):
        """Read the key pair located at ``path``.

        When ``path`` is the default credentials directory (with or without
        ``~`` expanded), the key file ``self.key`` inside it is read instead.

        Returns:
            The result of :meth:`extract_secrets_from_path` for the resolved path.
        """
        default_dirs = (
            self._DEFAULT_CREDENTIALS_PATH,
            os.path.expanduser(self._DEFAULT_CREDENTIALS_PATH),
        )
        if path in default_dirs:
            path = f"{path}/{self.key}"
        return self.extract_secrets_from_path(path)

    @staticmethod
    def extract_secrets_from_path(path: str) -> Dict:
        """Read a key pair from ``path`` and ``path + ".pub"``.

        Args:
            path: Path of the private key file; ``~`` is expanded.

        Returns:
            ``{"public_key": ..., "private_key": ...}`` with the text of both
            files, or an empty dict when either file is missing.
        """
        pub_key_path = os.path.expanduser(f"{path}.pub")
        priv_key_path = os.path.expanduser(path)

        if not (os.path.exists(pub_key_path) and os.path.exists(priv_key_path)):
            return {}

        pub_key = Path(pub_key_path).read_text()
        priv_key = Path(priv_key_path).read_text()
        return {"public_key": pub_key, "private_key": priv_key}

    def _file_to(
        self,
        key: str,
        system: Union[str, Cluster],
        path: Union[str, Path] = None,
        values: Any = None,
    ):
        """Write the key pair on ``system`` by calling ``_write_to_file`` there.

        Args:
            key: First argument forwarded to ``system.call`` (presumably the
                key of this secret on the system; confirm against callers).
            system: Object on which ``call`` and ``run`` are invoked.
            path: Destination of the private key, used only when this secret
                has no ``self.path``.
            values: Forwarded to ``_write_to_file`` as ``values``.

        Returns:
            The private key path that was passed to ``_write_to_file``:
            ``self.path`` when set, otherwise ``path``.
        """
        if self.path:
            remote_priv_file = self.path
            system.call(key, "_write_to_file", path=remote_priv_file, values=values)
            # chmod the file that was actually written. ``path`` may be None
            # in this branch. Left unquoted so a leading ``~`` still expands.
            system.run([f"chmod 600 {remote_priv_file}"])
        else:
            system.call(key, "_write_to_file", path=path, values=values)
            remote_priv_file = path
        return remote_priv_file