import asyncio
import contextlib
import subprocess
import time
import warnings
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import Any, Dict, List, Union
import requests
import rich.errors
try:
import sky
from sky import ClusterStatus as SkyClusterStatus
from sky.backends import backend_utils
except ImportError:
pass
from runhouse.constants import (
DEFAULT_HTTP_PORT,
DEFAULT_HTTPS_PORT,
DEFAULT_SERVER_PORT,
LOCAL_HOSTS,
)
from runhouse.globals import configs, obj_store, rns_client
from runhouse.logger import get_logger
from runhouse.resources.hardware.utils import (
_cluster_set_autostop_command,
ClusterStatus,
LauncherType,
pprint_launched_cluster_summary,
RunhouseDaemonStatus,
ServerConnectionType,
up_cluster_helper,
)
from .cluster import Cluster
from .launcher_utils import DenLauncher, LocalLauncher
# Module-level logger for this file, obtained from runhouse's logging helper.
logger = get_logger(__name__)
class OnDemandCluster(Cluster):
    """On-demand cluster whose lifecycle (up, autostop, teardown) is driven by a launcher.

    Two launchers are handled here: ``LauncherType.LOCAL`` (SkyPilot on this machine, with cluster
    state read back from the local Sky DB) and ``LauncherType.DEN`` (launching via Den).
    """

    RESOURCE_TYPE = "cluster"
    RECONNECT_TIMEOUT = 5
    DEFAULT_KEYFILE = "~/.ssh/sky-key"

    def __init__(
        self,
        name,
        instance_type: str = None,
        num_nodes: int = None,
        provider: str = None,
        dryrun: bool = False,
        autostop_mins: int = None,
        use_spot: bool = False,
        memory: Union[int, str] = None,
        disk_size: int = None,
        num_cpus: Union[int, str] = None,
        gpus: str = None,
        open_ports: Union[int, str, List[int]] = None,
        server_host: int = None,
        server_port: int = None,
        server_connection_type: str = None,
        launcher: str = None,
        ssl_keyfile: str = None,
        ssl_certfile: str = None,
        domain: str = None,
        den_auth: bool = False,
        region: str = None,
        vpc_name: str = None,
        sky_kwargs: Dict = None,
        **kwargs,  # We have this here to ignore extra arguments when calling from from_config
    ):
        """
        On-demand `SkyPilot <https://github.com/skypilot-org/skypilot/>`__ Cluster.

        .. note::
            To build a cluster, please use the factory method :func:`cluster`.
        """
        # Set before the parent constructor runs; presumably base-class initialization can reach
        # overridden hooks (e.g. ``_setup_default_creds``) that read ``self.launcher`` -- verify.
        self.launcher = launcher or configs.launcher
        super().__init__(
            name=name,
            server_host=server_host,
            server_port=server_port,
            server_connection_type=server_connection_type,
            ssl_keyfile=ssl_keyfile,
            ssl_certfile=ssl_certfile,
            domain=domain,
            den_auth=den_auth,
            dryrun=dryrun,
            **kwargs,
        )

        self.instance_type = instance_type
        self.num_nodes = num_nodes
        self.provider = provider or configs.get("default_provider")
        self._autostop_mins = (
            autostop_mins
            if autostop_mins is not None
            else configs.get("default_autostop")
        )

        self.open_ports = open_ports
        # NOTE(review): because ``use_spot`` defaults to False, the ``configs.get("use_spot")``
        # fallback only applies when a caller passes ``use_spot=None`` explicitly.
        self.use_spot = use_spot if use_spot is not None else configs.get("use_spot")
        self.region = region
        self.memory = memory
        self.disk_size = disk_size
        self._num_cpus = num_cpus
        self._gpus = gpus
        self.sky_kwargs = sky_kwargs or {}
        self.vpc_name = vpc_name

        # Compute properties merged in increasing priority: an explicit ``ips`` kwarg, then saved
        # ``compute_properties``, then ``launched_properties``. ``or {}`` tolerates keys stored as None.
        self.compute_properties = {}
        if kwargs.get("ips"):
            self.compute_properties["ips"] = kwargs.get("ips")
        self.compute_properties = {
            **self.compute_properties,
            **(kwargs.get("compute_properties") or {}),
            **(kwargs.get("launched_properties") or {}),
        }

        self._docker_user = None
        self._kube_namespace = kwargs.get(
            "kube_namespace"
        ) or self.compute_properties.get("kube_namespace")
        self._kube_context = kwargs.get("kube_context") or self.compute_properties.get(
            "kube_context"
        )
        self._cluster_status = kwargs.get("cluster_status")

        # Checks if state info is in local sky db, populates if so.
        if not dryrun and not self.ips and self.launcher == LauncherType.LOCAL:
            # Cluster status is set to INIT in the Sky DB right after starting, so we need to refresh once
            self._update_from_sky_status(dryrun=True)

    @property
    def ips(self):
        """External IPs of the cluster nodes, or an empty list if unknown."""
        return self.compute_properties.get("ips", [])

    @property
    def internal_ips(self):
        """Internal IPs of the cluster nodes, or an empty list if unknown."""
        return self.compute_properties.get("internal_ips", [])

    @property
    def client(self):
        """Return the base-class client, loading IPs from the local Sky DB if none are known yet.

        Raises:
            ValueError: If the base client cannot be built and no IPs can be determined.
        """
        try:
            return super().client
        except ValueError as e:
            if not self.ips:
                # Try loading in from local Sky DB
                self._update_from_sky_status(dryrun=True)
                if not self.ips:
                    raise ValueError(
                        f"Could not determine ips for ondemand cluster <{self.name}>. "
                        "Up the cluster with `cluster.up_if_not`."
                    )
                return super().client
            raise e

    @property
    def autostop_mins(self):
        """Minutes of inactivity before the cluster is stopped (-1 keeps it warm indefinitely)."""
        return self._autostop_mins

    @autostop_mins.setter
    def autostop_mins(self, mins):
        """Record the new value and, if the cluster is up, propagate it to the server and launcher."""
        self._autostop_mins = mins
        if not self.is_up():
            return

        if self.on_this_cluster():
            obj_store.set_cluster_config_value("autostop_mins", mins)
        else:
            self.call_client_method("set_settings", {"autostop_mins": mins})

        # Compare against LauncherType, as the rest of this class does.
        if self.launcher == LauncherType.LOCAL:
            LocalLauncher.keep_warm(self, mins)
        elif self.launcher == LauncherType.DEN:
            DenLauncher.keep_warm(self, mins)

    @property
    def image_id(self) -> str:
        """Image id of the attached image, or None if there is no image or it has no id."""
        if self.image and self.image.image_id:
            return self.image.image_id
        return None

    @property
    def docker_user(self) -> str:
        """User inside the docker container, or None when the image is not a ``docker:`` image."""
        if self._docker_user:
            return self._docker_user

        # TODO detect whether this is a k8s cluster properly, and handle the user setting / SSH properly
        #  (e.g. SkyPilot's new KubernetesCommandRunner)
        if not self.image_id or "docker:" not in self.image_id:
            return None

        if self.compute_properties.get("cloud") == "kubernetes":
            return self.compute_properties.get(
                "docker_user", self.compute_properties.get("ssh_user", "root")
            )

        from runhouse.resources.hardware.sky_command_runner import get_docker_user

        if not self._creds:
            return
        self._docker_user = get_docker_user(self, self.ssh_properties)
        return self._docker_user

    @property
    def cluster_status(self):
        """Last known runhouse ``ClusterStatus`` of this cluster (may be None)."""
        return self._cluster_status

    @cluster_status.setter
    def cluster_status(self, new_status: ClusterStatus):
        self._cluster_status = new_status

    def config(self, condensed=True):
        """Return the base config extended with on-demand launch settings and compute properties."""
        config = super().config(condensed)
        self.save_attrs_to_config(
            config,
            [
                "instance_type",
                "num_nodes",
                "provider",
                "open_ports",
                "use_spot",
                "region",
                "memory",
                "disk_size",
                "vpc_name",
                "sky_kwargs",
                "launcher",
                "compute_properties",
            ],
        )
        config["autostop_mins"] = self._autostop_mins
        config["num_cpus"] = self._num_cpus
        config["gpus"] = self._gpus

        if self._kube_namespace is not None:
            config["kube_namespace"] = self._kube_namespace
        if self._kube_context is not None:
            config["kube_context"] = self._kube_context

        return config

    def endpoint(self, external: bool = False):
        """Return the server endpoint, or None if the cluster has no IPs, is terminated, or is unreachable."""
        if (
            not self.ips
            or self.on_this_cluster()
            or self.cluster_status == ClusterStatus.TERMINATED
        ):
            return None

        try:
            self.client.check_server()
        except (ConnectionError, requests.exceptions.ConnectionError):
            # requests' ConnectionError does not subclass the builtin ConnectionError, so catch both.
            return None

        return super().endpoint(external)

    @staticmethod
    def relative_yaml_path(yaml_path):
        """Rewrite an absolute yaml path to ``~/.sky/generated/<file name>``; other paths pass through."""
        if Path(yaml_path).is_absolute():
            yaml_path = "~/.sky/generated/" + Path(yaml_path).name
        return yaml_path

    def _set_connection_defaults(self):
        """Fill in the connection type and server port defaults, and normalize ``open_ports``.

        ``open_ports`` becomes a list of strings. Warnings are emitted for configurations that need
        a tunnel, are insecure, or may leave the server port unreachable.
        """
        if not self.server_connection_type:
            if self.ssl_keyfile or self.ssl_certfile:
                self.server_connection_type = ServerConnectionType.TLS
            else:
                self.server_connection_type = ServerConnectionType.SSH

        if self.server_port is None:
            if self.server_connection_type == ServerConnectionType.TLS:
                self.server_port = DEFAULT_HTTPS_PORT
            elif self.server_connection_type == ServerConnectionType.NONE:
                self.server_port = DEFAULT_HTTP_PORT
            else:
                self.server_port = DEFAULT_SERVER_PORT

        if (
            self.server_connection_type
            in [ServerConnectionType.TLS, ServerConnectionType.NONE]
            and self.server_host in LOCAL_HOSTS
        ):
            warnings.warn(
                f"Server connection type: {self.server_connection_type}, server host: {self.server_host}. "
                f"Note that this will require opening an SSH tunnel to forward traffic from"
                f" {self.server_host} to the server."
            )

        self.open_ports = (
            []
            if self.open_ports is None
            else [self.open_ports]
            if isinstance(self.open_ports, (int, str))
            else self.open_ports
        )

        if self.open_ports:
            self.open_ports = [str(p) for p in self.open_ports]
            if str(self.server_port) in self.open_ports:
                if (
                    self.server_connection_type
                    in [ServerConnectionType.TLS, ServerConnectionType.NONE]
                    and not self.den_auth
                ):
                    warnings.warn(
                        "Server is insecure and must be inside a VPC or have `den_auth` enabled to secure it."
                    )
            else:
                warnings.warn(
                    f"Server port {self.server_port} not included in open ports. Note you are responsible for opening "
                    f"the port or ensure you have access to it via a VPC."
                )
        else:
            # If using HTTP or HTTPS must enable traffic on the relevant port
            if self.server_connection_type in [
                ServerConnectionType.TLS,
                ServerConnectionType.NONE,
            ]:
                if self.server_port:
                    warnings.warn(
                        f"No open ports specified. Setting default port {self.server_port} to open."
                    )
                    self.open_ports = [str(self.server_port)]
                else:
                    warnings.warn(
                        f"No open ports specified. Make sure the relevant port is open. "
                        f"HTTPS default: {DEFAULT_HTTPS_PORT} and HTTP "
                        f"default: {DEFAULT_HTTP_PORT}."
                    )

    # ----------------- Launch/Lifecycle Methods -----------------

    def is_up(self) -> bool:
        """Whether the cluster is up.

        Example:
            >>> rh.ondemand_cluster("rh-cpu").is_up()
        """
        if self.on_this_cluster():
            return True

        # Check sky status without refresh if locally launched
        if self.launcher == LauncherType.LOCAL:
            self._fetch_sky_status_and_update_cluster_status()
            if self.cluster_status == ClusterStatus.TERMINATED:
                return False

        return self._ping(retry=True)

    def _sky_status(self, refresh: bool = True, retry: bool = True):
        """
        Get status of Sky cluster.

        Return dict looks like:

        .. code-block::

            {'name': 'sky-cpunode-donny',
             'launched_at': 1662317201,
             'handle': ResourceHandle(
                      cluster_name=sky-cpunode-donny,
                      head_ip=54.211.97.164,
                      cluster_yaml=/Users/donny/.sky/generated/sky-cpunode-donny.yml,
                      launched_resources=1x AWS(m6i.2xlarge),
                      tpu_create_script=None,
                      tpu_delete_script=None),
             'last_use': 'sky cpunode',
             'status': <ClusterStatus.UP: 'UP'>,
             'autostop': -1,
             'metadata': {}}

        Returns None if the cluster is not in the local Sky DB.

        .. note:: For more information see SkyPilot's :code:`ResourceHandle` `class
        <https://github.com/skypilot-org/skypilot/blob/0c2b291b03abe486b521b40a3069195e56b62324/sky/backends/cloud_vm_ray_backend.py#L1457>`__.
        """
        if not sky.global_user_state.get_cluster_from_name(self.name):
            return None

        try:
            state = sky.status(cluster_names=[self.name], refresh=refresh)
        except rich.errors.LiveError as e:
            # We can't have more than one Live display at once, so if we've already launched one (e.g. the first
            # time we call status), we can retry without refreshing
            if not retry:
                raise e
            return self._sky_status(refresh=False, retry=False)

        # We still need to check if the cluster present in case the cluster went down and was removed from the DB
        if len(state) == 0:
            return None
        return state[0]

    def _start_ray_workers(self, ray_port, env_vars):
        """Start Ray on worker nodes, loading internal IPs from Sky first if they are missing."""
        if not self.internal_ips:
            self._update_from_sky_status()
        super()._start_ray_workers(ray_port, env_vars)
        time.sleep(5)

    def _update_cluster_status_from_sky_status(self, sky_status: "SkyClusterStatus"):
        """Map a SkyPilot status (UP / STOPPED / INIT) onto ``self.cluster_status``; others are ignored."""
        if sky_status == SkyClusterStatus.UP:
            self.cluster_status = ClusterStatus.RUNNING
        elif sky_status == SkyClusterStatus.STOPPED:
            self.cluster_status = ClusterStatus.TERMINATED
        elif sky_status == SkyClusterStatus.INIT:
            self.cluster_status = ClusterStatus.INITIALIZING

    def _fetch_sky_status_and_update_cluster_status(self, refresh: bool = False):
        """Refresh ``cluster_status`` from Sky, treating a cluster absent from the Sky DB as TERMINATED."""
        cluster_dict = self._sky_status(refresh=refresh)
        if not cluster_dict:
            self.cluster_status = ClusterStatus.TERMINATED
            return

        sky_status = cluster_dict["status"]
        self._update_cluster_status_from_sky_status(sky_status)

    def _populate_connection_from_status_dict(self, cluster_dict: Dict[str, Any]):
        """Populate status, credentials and compute properties from a Sky status dict.

        Connection info is only read when the Sky status is UP or INIT.

        Raises:
            ValueError: If Sky reports no node IPs or no head IP.
        """
        if not cluster_dict:
            return

        sky_status = cluster_dict["status"]
        self._update_cluster_status_from_sky_status(sky_status)

        if sky_status in [SkyClusterStatus.UP, SkyClusterStatus.INIT]:
            handle = cluster_dict["handle"]
            head_ip = handle.head_ip
            # Split the (internal, external) pairs without zip(*...), which raises an opaque unpacking
            # error (or TypeError on None) before the descriptive ValueError below can be raised.
            ip_pairs = handle.stable_internal_external_ips or []
            internal_ips = [internal_ip for internal_ip, _ in ip_pairs]
            ips = [external_ip for _, external_ip in ip_pairs]
            if not ips or not head_ip:
                raise ValueError(
                    "Sky's cluster status does not have the necessary information to connect to the cluster. Please check if the cluster is up via `sky status`. Consider bringing down the cluster with `sky down` if you are still having issues."
                )

            yaml_path = handle.cluster_yaml
            if Path(yaml_path).exists():
                ssh_values = backend_utils.ssh_credential_from_yaml(
                    yaml_path, ssh_user=handle.ssh_user
                )
                if not self.ssh_properties:
                    self._setup_creds(ssh_values)

            launched_resource = handle.launched_resources
            cloud = str(launched_resource.cloud).lower()
            instance_type = launched_resource.instance_type
            region = launched_resource.region
            cost_per_hr = launched_resource.get_cost(60 * 60)
            disk_size = launched_resource.disk_size
            num_cpus = launched_resource.cpus
            memory = launched_resource.memory

            self.compute_properties = {
                "ips": ips,
                "internal_ips": internal_ips,
                "cloud": cloud,
                "instance_type": instance_type,
                "region": region,
                "cost_per_hour": str(cost_per_hr),
                "disk_size": disk_size,
                "memory": memory,
                "num_cpus": num_cpus,
            }
            if launched_resource.accelerators:
                self.compute_properties["gpus"] = launched_resource.accelerators
            if handle.ssh_user:
                self.compute_properties["ssh_user"] = handle.ssh_user
            if handle.docker_user:
                self.compute_properties["docker_user"] = handle.docker_user

            if cloud == "kubernetes":
                if handle.cached_cluster_info:
                    self.compute_properties[
                        "kube_namespace"
                    ] = handle.cached_cluster_info.provider_config.get("namespace")
                    self.compute_properties[
                        "kube_context"
                    ] = handle.cached_cluster_info.provider_config.get("context")
                    instance_infos = list(handle.cached_cluster_info.instances.values())
                    pod_names_and_ips = {
                        instance_info[0].internal_ip: instance_info[0].instance_id
                        for instance_info in instance_infos
                    }
                    # Order the pod names to match the order of the IPs
                    self.compute_properties["pod_names"] = [
                        pod_names_and_ips[ip] for ip in self.ips
                    ]

                if not self.compute_properties.get(
                    "kube_namespace"
                ) or not self.compute_properties.get("pod_names"):
                    import kubernetes

                    k8s_client = kubernetes.client.CoreV1Api()
                    pod_names_and_ips = {
                        pod.status.pod_ip: (pod.metadata.name, pod.metadata.namespace)
                        for pod in k8s_client.list_pod_for_all_namespaces().items
                    }
                    # Order the pod names to match the order of the IPs
                    self.compute_properties["pod_names"] = [
                        pod_names_and_ips[ip][0] for ip in self.ips
                    ]
                    # Get the namespace for the first pod
                    self.compute_properties["kube_namespace"] = pod_names_and_ips[
                        self.head_ip
                    ][1]

                if not self.compute_properties.get("kube_context"):
                    import kubernetes

                    _, current_context = kubernetes.config.list_kube_config_contexts()
                    self.compute_properties["kube_context"] = current_context["name"]

                self._kube_namespace = self.compute_properties.get("kube_namespace")
                self._kube_context = self.compute_properties.get("kube_context")

    def _update_from_sky_status(self, dryrun: bool = False):
        """Load connection info from the local Sky DB (local launcher, non-shared clusters only).

        Args:
            dryrun (bool): If True, read Sky status without refreshing it.
        """
        if self.launcher != LauncherType.LOCAL:
            return

        # Try to get the cluster status from SkyDB
        if self._is_shared:
            # If the cluster is shared can ignore, since the sky data will only be saved on the machine where
            # the cluster was initially upped
            return

        cluster_dict = self._sky_status(refresh=not dryrun)
        self._populate_connection_from_status_dict(cluster_dict)

    def _setup_default_creds(self):
        """Setup the default creds used in launching and for interacting with the cluster once it's up.

        For Den launching we load the default ssh creds, and for local launching we let Sky handle it."""
        return DenLauncher.load_creds() if self.launcher == LauncherType.DEN else None

    def get_instance_type(self):
        """Returns instance type of the cluster."""
        if self.instance_type and "--" in self.instance_type:  # K8s specific syntax
            return self.instance_type
        elif (
            self.instance_type
            and ":" not in self.instance_type
            and "CPU" not in self.instance_type
        ):
            return self.instance_type

        return None

    def _requested_gpus(self):
        """Returns the gpu type, or None if is a CPU."""
        if self._gpus:
            return self._gpus

        if (
            self.instance_type
            and ":" in self.instance_type
            and "CPU" not in self.instance_type
        ):
            return self.instance_type

        return None

    def _gpus_per_node(self):
        """Return the number of GPUs per node (0 for CPU-only clusters).

        When the cluster is up, the launched accelerators are used; otherwise the requested ones.
        Fractional counts are rounded up.
        """
        import math

        if (
            self.is_up()
            and self.compute_properties
            and self.compute_properties.get("gpus")
        ):
            gpus = self.compute_properties.get("gpus")
        else:
            gpus = self._requested_gpus()

        if not gpus:
            return 0
        if isinstance(gpus, dict):
            # Launched accelerators (stored from SkyPilot's ``accelerators``) are a mapping such as
            # {"A100": 8}. The old ``":" in gpus`` check tested dict keys and always yielded 1.
            return math.ceil(sum(gpus.values()))
        if ":" not in gpus:
            return 1
        return math.ceil(float(gpus.rsplit(":", 1)[-1]))

    def num_cpus(self):
        """Return the number of CPUs for a CPU cluster."""
        if self._num_cpus:
            return self._num_cpus

        if (
            self.instance_type
            and ":" in self.instance_type
            and "CPU" in self.instance_type
        ):
            return self.instance_type.rsplit(":", 1)[1]

        return None

    async def a_up(self, capture_output: Union[bool, str] = True):
        """Up the cluster async in another process, so it can be parallelized and logs can be captured sanely.

        Args:
            capture_output (bool): If ``True``, suppress the output of the cluster creation process. If ``False``,
                print the output normally. If a string, write the output to the file at that path.
        """
        with ProcessPoolExecutor() as executor:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                executor, up_cluster_helper, self, capture_output
            )
        return self

    async def a_up_if_not(self, capture_output: Union[bool, str] = True):
        """Async version of ``up_if_not``: launch via :meth:`a_up` only if the cluster is not up."""
        if not self.is_up():
            await self.a_up(capture_output=capture_output)
        return self

    def up(self, verbose: bool = True, force: bool = False, start_server: bool = True):
        """Up the cluster.

        Args:
            verbose (bool, optional): Whether to stream logs from Den when the cluster is being launched. Only
                relevant if launching via Den. (Default: `True`)
            force (bool, optional): Whether to launch the cluster even if one with the same configs already exists.
                Only relevant if launching via Den. (Default: `False`)
            start_server (bool, optional): Whether to start the Runhouse server after launching. (Default: `True`)

        Example:
            >>> rh.ondemand_cluster("rh-cpu").up()
        """
        if self.on_this_cluster():
            return self

        if self._is_shared:
            logger.warning(
                "Cannot up a shared cluster. Only cluster owners can perform this operation."
            )
            return self

        if self.launcher == LauncherType.DEN:
            logger.info("Launching cluster with Den")
            DenLauncher.up(cluster=self, verbose=verbose, force=force)
        elif self.launcher == LauncherType.LOCAL:
            logger.info("Provisioning cluster")
            LocalLauncher.up(cluster=self, verbose=verbose)

        if start_server:
            logger.info("Starting Runhouse server on cluster")
            self.start_server()

        pprint_launched_cluster_summary(cluster=self)

        return self

    def keep_warm(self, mins: int = -1):
        """Keep the cluster warm for given number of minutes after inactivity.

        Args:
            mins (int): Amount of time (in min) to keep the cluster warm after inactivity.
                If set to -1, keep cluster warm indefinitely. (Default: `-1`)
        """
        self.autostop_mins = mins
        return self

    def teardown(self, verbose: bool = True):
        """Teardown cluster.

        Args:
            verbose (bool, optional): Whether to stream logs from Den when the cluster is being downed. Only relevant
                when tearing down via Den. (Default: `True`)

        Example:
            >>> rh.ondemand_cluster("rh-cpu").teardown()
        """
        if self.launcher == LauncherType.DEN:
            logger.info("Tearing down cluster with Den.")
            DenLauncher.teardown(cluster=self, verbose=verbose)
        else:
            logger.info("Tearing down cluster locally via Sky.")
            LocalLauncher.teardown(cluster=self, verbose=verbose)

            if self.rns_address is not None:
                try:
                    # Update Den with the terminated status
                    status_data = {
                        "daemon_status": RunhouseDaemonStatus.TERMINATED,
                        "resource_type": self.__class__.__base__.__name__.lower(),
                        "data": {},
                    }
                    cluster_uri = rns_client.format_rns_address(self.rns_address)
                    status_resp = requests.post(
                        f"{rns_client.api_server_url}/resource/{cluster_uri}/cluster/status",
                        json=status_data,
                        headers=rns_client.request_headers(),
                    )
                    # Note: 404 means that the cluster is not saved in Den
                    if status_resp.status_code not in [200, 404]:
                        logger.warning(
                            "Failed to update Den with terminated cluster status"
                        )
                except Exception as e:
                    # Best effort: the teardown itself already happened, so only log the failure.
                    logger.warning(e)

    def teardown_and_delete(self, verbose: bool = True):
        """Teardown cluster and delete it from configs.

        Args:
            verbose (bool, optional): Whether to stream logs from Den when the cluster is being downed. Only relevant
                when tearing down via Den. (Default: `True`)

        Example:
            >>> rh.ondemand_cluster("rh-cpu").teardown_and_delete()
        """
        self.teardown(verbose)
        rns_client.delete_configs(resource=self)

    @contextlib.contextmanager
    def pause_autostop(self):
        """Context manager to temporarily pause autostop.

        Autostop is disabled (-1) on entry and restored to ``self._autostop_mins`` on exit, including
        when the body raises.

        Example:
            >>> cluster = rh.ondemand_cluster("rh-cpu")
            >>> with cluster.pause_autostop():
            >>>     cluster.run_bash(["python train.py"])
        """
        self.run_bash_over_ssh(_cluster_set_autostop_command(-1), node=self.head_ip)
        try:
            yield
        finally:
            # Restore with the same helper used on entry (previously misspelled ``run_bash_over_ssh_``).
            self.run_bash_over_ssh(
                _cluster_set_autostop_command(self._autostop_mins), node=self.head_ip
            )

    # ----------------- SSH Methods ----------------- #

    @staticmethod
    def cluster_ssh_key(path_to_file: Path):
        """Retrieve SSH key for the cluster.

        Args:
            path_to_file (Path): Path of the private key associated with the cluster. ``~`` is expanded.

        Returns:
            str: Contents of the private key file.

        Raises:
            Exception: If no file exists at ``path_to_file``.

        Example:
            >>> ssh_priv_key = rh.ondemand_cluster("rh-cpu").cluster_ssh_key("~/.ssh/id_rsa")
        """
        try:
            # read_text opens and closes the file, unlike the previous bare open() which leaked the handle.
            return Path(path_to_file).expanduser().read_text()
        except FileNotFoundError as e:
            raise Exception(f"File with ssh key not found in: {path_to_file}") from e

    def ssh(self, node: str = None):
        """SSH into the cluster.

        Args:
            node: Node to SSH into. If no node is specified, will SSH onto the head node.
                (Default: ``None``)

        Raises:
            Exception: For kubernetes clusters, if pods cannot be listed or no pod matches the cluster name.

        Example:
            >>> rh.ondemand_cluster("rh-cpu").ssh()
            >>> rh.ondemand_cluster("rh-cpu").ssh(node="3.89.174.234")
        """
        if self.provider == "kubernetes":
            # Build argv lists instead of shell strings so the cluster name / namespace are never
            # interpreted by a shell.
            namespace_args = (
                ["-n", self._kube_namespace] if self._kube_namespace else []
            )
            try:
                output = subprocess.check_output(
                    ["kubectl", "get", "pods", *namespace_args], text=True
                )
            except (subprocess.CalledProcessError, FileNotFoundError) as e:
                raise Exception(f"Error: {e}") from e

            # Equivalent of the former `| grep <name>`: keep rows containing the cluster name.
            matching_lines = [line for line in output.splitlines() if self.name in line]
            if not matching_lines:
                raise Exception(f"No matching pods found for cluster {self.name}.")
            pod_name = matching_lines[0].split()[0]

            subprocess.run(
                ["kubectl", "exec", "-it", pod_name, *namespace_args, "--", "/bin/bash"],
                check=True,
            )
        else:
            # If SSHing onto a specific node, which requires an SSH public key for verification
            # Note: the SSH key must either be the one used for launch, or the user's default SSH public key
            # if the cluster is shared
            from runhouse.resources.hardware.sky_command_runner import SshMode

            if self._is_shared:
                # If the cluster is shared auth will be based on the user's default SSH key
                default_ssh_key = configs.get("default_ssh_key")
                if default_ssh_key is None:
                    raise ValueError("No default SSH key found local Runhouse config")
            else:
                ssh_private_key_path = self.ssh_properties.get("ssh_private_key")
                if (
                    ssh_private_key_path is None
                    or not Path(ssh_private_key_path).expanduser().exists()
                ):
                    # SSH keys used for launching must be present if it's not a shared cluster
                    raise FileNotFoundError(
                        f"Expected SSH key in path: {ssh_private_key_path}"
                    )

            runner = self._command_runner(node=node)
            if self.docker_user:
                cmd = runner.run(
                    cmd="bash --rcfile <(echo '. ~/.bashrc; conda deactivate')",
                    ssh_mode=SshMode.INTERACTIVE,
                    port_forward=None,
                    return_cmd=True,
                )
                subprocess.run(cmd, shell=True)
            else:
                subprocess.run(
                    runner._ssh_base_command(
                        ssh_mode=SshMode.INTERACTIVE, port_forward=None
                    )
                )

    def _ping(self, timeout=5, retry=False):
        """Ping the server; on failure with ``retry``, refresh connection info from Sky and ping once more."""
        if super()._ping(timeout=timeout, retry=False):
            return True
        if retry:
            self._update_from_sky_status(dryrun=False)
            return super()._ping(timeout=timeout, retry=False)
        return False