import asyncio
import contextlib
import subprocess
import time
import warnings
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import Any, Dict, List, Union
import requests
import rich.errors
try:
import sky
from sky import ClusterStatus as SkyClusterStatus
from sky.backends import backend_utils
except ImportError:
pass
from runhouse.constants import (
DEFAULT_HTTP_PORT,
DEFAULT_HTTPS_PORT,
DEFAULT_SERVER_PORT,
LOCAL_HOSTS,
)
from runhouse.globals import configs, obj_store, rns_client
from runhouse.logger import get_logger
from runhouse.resources.hardware.utils import (
_cluster_set_autostop_command,
ClusterStatus,
LauncherType,
RunhouseDaemonStatus,
ServerConnectionType,
up_cluster_helper,
)
from .cluster import Cluster
from .launcher_utils import DenLauncher, LocalLauncher
logger = get_logger(__name__)
[docs]class OnDemandCluster(Cluster):
# Resource type identifier for this class.
RESOURCE_TYPE = "cluster"
# NOTE(review): presumably a reconnect timeout in seconds; it is not referenced
# anywhere in the visible part of this file -- confirm against the base class.
RECONNECT_TIMEOUT = 5
# Fallback private key path used by `ssh()` when the cluster creds do not carry
# an "ssh_private_key" entry.
DEFAULT_KEYFILE = "~/.ssh/sky-key"
def __init__(
    self,
    name,
    instance_type: str = None,
    num_nodes: int = None,
    provider: str = None,
    default_env: "Env" = None,
    dryrun: bool = False,
    autostop_mins: int = None,
    use_spot: bool = False,
    memory: Union[int, str] = None,
    disk_size: Union[int, str] = None,
    num_cpus: Union[int, str] = None,
    accelerators: str = None,
    open_ports: Union[int, str, List[int]] = None,
    server_host: int = None,
    server_port: int = None,
    server_connection_type: str = None,
    launcher: str = None,
    ssl_keyfile: str = None,
    ssl_certfile: str = None,
    domain: str = None,
    den_auth: bool = False,
    region: str = None,
    sky_kwargs: Dict = None,
    **kwargs,  # Extra arguments are accepted (and mostly ignored) so from_config can pass a full saved config
):
    """
    On-demand `SkyPilot <https://github.com/skypilot-org/skypilot/>`__ Cluster.

    .. note::
        To build a cluster, please use the factory method :func:`cluster`.
    """
    # An explicit launcher wins; otherwise fall back to the configured one.
    resolved_launcher = launcher or configs.launcher

    super().__init__(
        name=name,
        default_env=default_env,
        server_host=server_host,
        server_port=server_port,
        server_connection_type=server_connection_type,
        ssl_keyfile=ssl_keyfile,
        ssl_certfile=ssl_certfile,
        domain=domain,
        den_auth=den_auth,
        dryrun=dryrun,
        # Credential setup is skipped when the cluster is launched through Den.
        skip_creds=(resolved_launcher == LauncherType.DEN),
        **kwargs,
    )

    # Backwards compatibility: configs used to be saved with `num_instances`.
    if not num_nodes and "num_instances" in kwargs:
        num_nodes = kwargs.get("num_instances")

    # Values left unset fall back to the user's configured defaults.
    if autostop_mins is None:
        autostop_mins = configs.get("default_autostop")
    if use_spot is None:
        use_spot = configs.get("use_spot")

    self.instance_type = instance_type
    self.num_nodes = num_nodes
    self.provider = provider or configs.get("default_provider")
    self._autostop_mins = autostop_mins
    self.open_ports = open_ports
    self.use_spot = use_spot
    self.region = region
    self.memory = memory
    self.disk_size = disk_size
    self._num_cpus = num_cpus
    self._accelerators = accelerators
    self.sky_kwargs = sky_kwargs or {}
    self.launcher = resolved_launcher

    # Backwards compatibility: IPs used to be stored as top-level config keys,
    # either as (internal, external) pairs or as a flat list of IPs.
    legacy_properties = {}
    ip_pairs = kwargs.get("stable_internal_external_ips")
    if ip_pairs:
        internal_ips, ips = map(list, zip(*ip_pairs))
        legacy_properties["ips"] = ips
        legacy_properties["internal_ips"] = internal_ips
    elif kwargs.get("ips"):
        legacy_properties["ips"] = kwargs.get("ips")

    # Later sources override earlier ones on key collisions.
    self.compute_properties = {
        **legacy_properties,
        **kwargs.get("compute_properties", {}),
        **kwargs.get("launched_properties", {}),
    }

    self._docker_user = None
    self._namespace = kwargs.get("namespace")
    self._context = kwargs.get("context")
    self._cluster_status = kwargs.get("cluster_status")

    # If nothing is known about the cluster yet, try to populate its state from
    # the local Sky DB (without refreshing, hence dryrun=True).
    if not (dryrun or self.ips or self.creds_values):
        self._update_from_sky_status(dryrun=True)
@property
def ips(self):
    """IPs stored under ``compute_properties["ips"]``, or an empty list when none are recorded."""
    properties = self.compute_properties
    return properties.get("ips", [])
@property
def internal_ips(self):
    """IPs stored under ``compute_properties["internal_ips"]``, or an empty list when none are recorded."""
    properties = self.compute_properties
    return properties.get("internal_ips", [])
@property
def client(self):
    """Return the base class client, loading IPs from the local Sky DB if needed.

    If the base property raises ``ValueError`` while no IPs are known, the
    connection info is loaded from the local Sky DB and the lookup is retried
    once.

    Raises:
        ValueError: The base class error when IPs were already known, or a
            dedicated error when no IPs could be determined.
    """
    try:
        return super().client
    except ValueError as err:
        if self.ips:
            # IPs are known, so the failure is unrelated to missing addresses.
            raise err

        # Try loading in from local Sky DB
        self._update_from_sky_status(dryrun=True)
        if self.ips:
            return super().client

        raise ValueError(
            f"Could not determine ips for ondemand cluster <{self.name}>. "
            "Up the cluster with `cluster.up_if_not`."
        )
@property
def autostop_mins(self):
    """The autostop setting (in minutes) currently recorded on this object."""
    current = self._autostop_mins
    return current
@autostop_mins.setter
def autostop_mins(self, mins):
# Always record the new value on this object, even when the cluster is down.
self._autostop_mins = mins
# Nothing further to update while the cluster is not up.
if not self.is_up():
return
# When running on the cluster itself, write the value into the cluster config
# through the object store; otherwise send it to the server as a settings update.
if self.on_this_cluster():
obj_store.set_cluster_config_value("autostop_mins", mins)
else:
self.call_client_method("set_settings", {"autostop_mins": mins})
# Ask the launcher that owns the cluster to apply the new autostop value.
# NOTE(review): the launcher is compared to string literals here, while other
# methods compare against LauncherType members -- confirm both forms are equivalent.
if self.launcher == "local":
LocalLauncher.keep_warm(self, mins)
elif self.launcher == "den":
DenLauncher.keep_warm(self, mins)
@property
def image_id(self) -> str:
    """The ``image_id`` of the cluster image, or ``None`` if there is no image or no id."""
    image = self.image
    if not image:
        return None
    return image.image_id if image.image_id else None
@property
def docker_user(self) -> str:
    """User to use inside the docker container, or ``None`` when not applicable.

    A previously resolved user is returned as-is. Clusters whose image id does
    not contain ``"docker:"`` have no docker user. For Kubernetes the user comes
    from ``compute_properties``; otherwise it is looked up via
    ``get_docker_user`` and cached, provided the cluster has credentials.
    """
    cached_user = self._docker_user
    if cached_user:
        return cached_user

    # TODO detect whether this is a k8s cluster properly, and handle the user setting / SSH properly
    #  (e.g. SkyPilot's new KubernetesCommandRunner)
    image_id = self.image_id
    if not image_id or "docker:" not in image_id:
        return None

    properties = self.compute_properties
    if properties.get("cloud") == "kubernetes":
        # Prefer an explicit docker user, then the ssh user, then "root".
        fallback_user = properties.get("ssh_user", "root")
        return properties.get("docker_user", fallback_user)

    from runhouse.resources.hardware.sky_command_runner import get_docker_user

    if not self._creds:
        return

    self._docker_user = get_docker_user(self, self.creds_values)
    return self._docker_user
def config(self, condensed=True):
    """Build the config dict for this cluster on top of the base class config.

    Adds the on-demand specific attributes, the autostop / CPU / accelerator
    settings and, when set, the namespace and context.
    """
    cfg = super().config(condensed)

    saved_attrs = [
        "instance_type",
        "num_nodes",
        "provider",
        "open_ports",
        "use_spot",
        "region",
        "memory",
        "disk_size",
        "sky_kwargs",
        "launcher",
        "compute_properties",
    ]
    self.save_attrs_to_config(cfg, saved_attrs)

    # These are backed by private attributes, so they are written explicitly.
    always_saved = (
        ("autostop_mins", self._autostop_mins),
        ("num_cpus", self._num_cpus),
        ("accelerators", self._accelerators),
    )
    for key, value in always_saved:
        cfg[key] = value

    # Only persisted when they carry a value.
    saved_when_set = (
        ("namespace", self._namespace),
        ("context", self._context),
    )
    for key, value in saved_when_set:
        if value is not None:
            cfg[key] = value

    return cfg
def endpoint(self, external: bool = False):
    """Return the cluster endpoint, or ``None`` when it cannot be reached.

    ``None`` is returned when no IPs are known, when running on the cluster
    itself, when the cluster status is TERMINATED, or when the server check
    raises ``ConnectionError``. Otherwise the base class endpoint is returned.
    """
    no_endpoint = (
        not self.ips
        or self.on_this_cluster()
        or self._cluster_status == ClusterStatus.TERMINATED
    )
    if no_endpoint:
        return None

    try:
        self.client.check_server()
    except ConnectionError:
        return None
    else:
        return super().endpoint(external)
@staticmethod
def relative_yaml_path(yaml_path):
    """Map an absolute yaml path to ``~/.sky/generated/<file name>``; return other paths unchanged."""
    path = Path(yaml_path)
    if not path.is_absolute():
        return yaml_path
    return "~/.sky/generated/" + path.name
def set_connection_defaults(self):
"""Fill in unset connection settings and normalize ``open_ports``.

Chooses a default ``server_connection_type`` and ``server_port`` when they are
unset, normalizes ``open_ports`` to a list of port strings, and emits warnings
for configurations that need extra setup or are insecure.
"""
# Default connection type: TLS when any SSL file is provided, SSH otherwise.
if not self.server_connection_type:
if self.ssl_keyfile or self.ssl_certfile:
self.server_connection_type = ServerConnectionType.TLS
else:
self.server_connection_type = ServerConnectionType.SSH
# Default port follows the connection type: HTTPS port for TLS, HTTP port for
# NONE, and the default server port for everything else.
if self.server_port is None:
if self.server_connection_type == ServerConnectionType.TLS:
self.server_port = DEFAULT_HTTPS_PORT
elif self.server_connection_type == ServerConnectionType.NONE:
self.server_port = DEFAULT_HTTP_PORT
else:
self.server_port = DEFAULT_SERVER_PORT
# TLS / NONE combined with a local server host: warn that an SSH tunnel is needed.
if (
self.server_connection_type
in [ServerConnectionType.TLS, ServerConnectionType.NONE]
and self.server_host in LOCAL_HOSTS
):
warnings.warn(
f"Server connection type: {self.server_connection_type}, server host: {self.server_host}. "
f"Note that this will require opening an SSH tunnel to forward traffic from"
f" {self.server_host} to the server."
)
# Normalize open_ports to a list: None -> [], a single int/str -> one-element list,
# anything else is kept as given.
self.open_ports = (
[]
if self.open_ports is None
else [self.open_ports]
if isinstance(self.open_ports, (int, str))
else self.open_ports
)
if self.open_ports:
# Ports are compared and stored as strings from here on.
self.open_ports = [str(p) for p in self.open_ports]
if str(self.server_port) in self.open_ports:
# The server port is open: warn when TLS / NONE is used without den_auth.
if (
self.server_connection_type
in [ServerConnectionType.TLS, ServerConnectionType.NONE]
and not self.den_auth
):
warnings.warn(
"Server is insecure and must be inside a VPC or have `den_auth` enabled to secure it."
)
else:
warnings.warn(
f"Server port {self.server_port} not included in open ports. Note you are responsible for opening "
f"the port or ensure you have access to it via a VPC."
)
else:
# If using HTTP or HTTPS must enable traffic on the relevant port
if self.server_connection_type in [
ServerConnectionType.TLS,
ServerConnectionType.NONE,
]:
if self.server_port:
warnings.warn(
f"No open ports specified. Setting default port {self.server_port} to open."
)
# Open the server port since no ports were given.
self.open_ports = [str(self.server_port)]
else:
warnings.warn(
f"No open ports specified. Make sure the relevant port is open. "
f"HTTPS default: {DEFAULT_HTTPS_PORT} and HTTP "
f"default: {DEFAULT_HTTP_PORT}."
)
# ----------------- Launch/Lifecycle Methods -----------------
[docs] def is_up(self) -> bool:
"""Whether the cluster is up.

Returns ``True`` immediately when running on the cluster itself, ``False`` when
the cluster status is TERMINATED, and otherwise the result of pinging the
cluster with retry enabled.

Example:
>>> rh.ondemand_cluster("rh-cpu").is_up()
"""
# This local import rebinds ClusterStatus, which is also imported at module level.
from runhouse.resources.hardware.utils import ClusterStatus
if self.on_this_cluster():
return True
# Check sky status without refresh if locally launched
if self.launcher == LauncherType.LOCAL:
self._fetch_sky_status_and_update_cluster_status()
if self._cluster_status == ClusterStatus.TERMINATED:
return False
# Ping with retry; the retry path reloads connection info from Sky before pinging again.
return self._ping(retry=True)
def _sky_status(self, refresh: bool = True, retry: bool = True):
    """
    Get status of Sky cluster.

    Returns ``None`` when the cluster is not in the local Sky state, or when
    ``sky.status`` returns no record for it. If ``sky.status`` raises a rich
    ``LiveError`` and ``retry`` is set, the call is repeated once without
    refreshing.

    Return dict looks like:

    .. code-block::

        {'name': 'sky-cpunode-donny',
         'launched_at': 1662317201,
         'handle': ResourceHandle(
                      cluster_name=sky-cpunode-donny,
                      head_ip=54.211.97.164,
                      cluster_yaml=/Users/donny/.sky/generated/sky-cpunode-donny.yml,
                      launched_resources=1x AWS(m6i.2xlarge),
                      tpu_create_script=None,
                      tpu_delete_script=None),
         'last_use': 'sky cpunode',
         'status': <ClusterStatus.UP: 'UP'>,
         'autostop': -1,
         'metadata': {}}

    .. note:: For more information see SkyPilot's :code:`ResourceHandle` `class
        <https://github.com/skypilot-org/skypilot/blob/0c2b291b03abe486b521b40a3069195e56b62324/sky/backends/cloud_vm_ray_backend.py#L1457>`__.
    """
    known_locally = sky.global_user_state.get_cluster_from_name(self.name)
    if not known_locally:
        return None

    try:
        records = sky.status(cluster_names=[self.name], refresh=refresh)
    except rich.errors.LiveError as live_err:
        # Only one Live display can be active at a time, so if one is already
        # running (e.g. from a first status call), retry once without refreshing.
        if retry:
            return self._sky_status(refresh=False, retry=False)
        raise live_err

    # The cluster may have gone down and been removed from the DB in the meantime.
    return records[0] if len(records) != 0 else None
def _start_ray_workers(self, ray_port, env_vars):
    """Start the Ray workers via the base class, loading internal IPs from Sky first if unknown."""
    have_internal_ips = bool(self.internal_ips)
    if not have_internal_ips:
        self._update_from_sky_status()

    super()._start_ray_workers(ray_port, env_vars)

    # NOTE(review): fixed 5 second pause after starting the workers; presumably
    # gives them time to come up -- confirm.
    time.sleep(5)
def _update_cluster_status_from_sky_status(self, sky_status: str):
    """Translate a Sky cluster status into ``self._cluster_status``.

    UP maps to RUNNING, STOPPED to TERMINATED and INIT to INITIALIZING. Any
    other value leaves the current status untouched.
    """
    status_mapping = (
        (SkyClusterStatus.UP, ClusterStatus.RUNNING),
        (SkyClusterStatus.STOPPED, ClusterStatus.TERMINATED),
        (SkyClusterStatus.INIT, ClusterStatus.INITIALIZING),
    )
    for sky_state, cluster_state in status_mapping:
        if sky_status == sky_state:
            self._cluster_status = cluster_state
def _fetch_sky_status_and_update_cluster_status(self):
cluster_dict = self._sky_status(refresh=False)
if not cluster_dict:
self._cluster_status = ClusterStatus.TERMINATED
return
sky_status = cluster_dict["status"]
self._update_cluster_status_from_sky_status(sky_status)
def _populate_connection_from_status_dict(self, cluster_dict: Dict[str, Any]):
"""Populate status, credentials and ``compute_properties`` from a Sky status record.

Args:
cluster_dict: Record as returned by ``_sky_status``; the "status" and "handle"
keys are read. A falsy value makes this a no-op.

Raises:
ValueError: If the handle has no IPs or no head IP.
"""
if not cluster_dict:
return
sky_status = cluster_dict["status"]
self._update_cluster_status_from_sky_status(sky_status)
# Connection details are only read for clusters that are up or initializing.
if sky_status in [SkyClusterStatus.UP, SkyClusterStatus.INIT]:
handle = cluster_dict["handle"]
head_ip = handle.head_ip
# Split the (internal, external) pairs into two parallel lists.
internal_ips, ips = map(list, zip(*handle.stable_internal_external_ips))
if not ips or not head_ip:
raise ValueError(
"Sky's cluster status does not have the necessary information to connect to the cluster. Please check if the cluster is up via `sky status`. Consider bringing down the cluster with `sky down` if you are still having issues."
)
# Load SSH credentials from the cluster yaml when it exists locally; they are
# only applied if creds or ssh properties are not already set.
yaml_path = handle.cluster_yaml
if Path(yaml_path).exists():
ssh_values = backend_utils.ssh_credential_from_yaml(
yaml_path, ssh_user=handle.ssh_user
)
if not self.creds_values or not self.ssh_properties:
self._setup_creds(ssh_values)
launched_resource = handle.launched_resources
cloud = str(launched_resource.cloud).lower()
instance_type = launched_resource.instance_type
region = launched_resource.region
# NOTE(review): presumably get_cost takes a duration in seconds, making this the
# hourly cost (it is stored as "cost_per_hour" below) -- confirm against SkyPilot.
cost_per_hr = launched_resource.get_cost(60 * 60)
disk_size = launched_resource.disk_size
num_cpus = launched_resource.cpus
memory = launched_resource.memory
accelerators = launched_resource.accelerators
# Replace (rather than merge into) the previously stored compute properties.
self.compute_properties = {
"ips": ips,
"internal_ips": internal_ips,
"cloud": cloud,
"instance_type": instance_type,
"region": region,
"cost_per_hour": str(cost_per_hr),
"disk_size": disk_size,
"memory": memory,
"accelerators": accelerators,
"num_cpus": num_cpus,
}
# Writes the same accelerators value that was already placed in the dict above.
if launched_resource.accelerators:
self.compute_properties["accelerators"] = launched_resource.accelerators
if handle.ssh_user:
self.compute_properties["ssh_user"] = handle.ssh_user
if handle.docker_user:
self.compute_properties["docker_user"] = handle.docker_user
# Kubernetes clusters additionally record namespace, context and pod names.
if cloud == "kubernetes":
if handle.cached_cluster_info:
self.compute_properties[
"namespace"
] = handle.cached_cluster_info.provider_config.get("namespace")
self.compute_properties[
"context"
] = handle.cached_cluster_info.provider_config.get("context")
# Map each instance's internal IP to its instance id, using the first entry
# of every instance info list.
instance_infos = list(handle.cached_cluster_info.instances.values())
pod_names_and_ips = {
instance_info[0].internal_ip: instance_info[0].instance_id
for instance_info in instance_infos
}
# Order the pod names to match the order of the IPs
self.compute_properties["pod_names"] = [
pod_names_and_ips[ip] for ip in self.ips
]
# Fall back to the Kubernetes API when the namespace or pod names are still missing.
if not self.compute_properties.get(
"namespace"
) or not self.compute_properties.get("pod_names"):
import kubernetes
k8s_client = kubernetes.client.CoreV1Api()
# Map pod IP -> (pod name, namespace) across all namespaces.
pod_names_and_ips = {
pod.status.pod_ip: (pod.metadata.name, pod.metadata.namespace)
for pod in k8s_client.list_pod_for_all_namespaces().items
}
# Order the pod names to match the order of the IPs
self.compute_properties["pod_names"] = [
pod_names_and_ips[ip][0] for ip in self.ips
]
# Get the namespace for the first pod
self.compute_properties["namespace"] = pod_names_and_ips[
self.head_ip
][1]
# Fall back to the current kube config context when none was recorded.
if not self.compute_properties.get("context"):
import kubernetes
_, current_context = kubernetes.config.list_kube_config_contexts()
self.compute_properties["context"] = current_context["name"]
def _update_from_sky_status(self, dryrun: bool = False):
# Try to get the cluster status from SkyDB
if self.is_shared:
# If the cluster is shared can ignore, since the sky data will only be saved on the machine where
# the cluster was initially upped
return
if self.launcher == "local":
cluster_dict = self._sky_status(refresh=not dryrun)
self._populate_connection_from_status_dict(cluster_dict)
def get_instance_type(self):
    """Return ``instance_type`` when it names an instance type, otherwise ``None``.

    Values containing ``"--"`` (K8s specific syntax) are always returned.
    Values containing ``":"`` or ``"CPU"`` are not treated as instance types.
    """
    instance_type = self.instance_type
    if not instance_type:
        return None
    if "--" in instance_type:  # K8s specific syntax
        return instance_type
    if ":" in instance_type or "CPU" in instance_type:
        return None
    return instance_type
def accelerators(self):
    """Return the accelerator spec, or ``None`` for a CPU cluster.

    An explicitly set accelerators value wins. Otherwise ``instance_type`` is
    returned when it contains ``":"`` and does not contain ``"CPU"``.
    """
    explicit = self._accelerators
    if explicit:
        return explicit

    instance_type = self.instance_type
    if not instance_type:
        return None
    if ":" not in instance_type or "CPU" in instance_type:
        return None
    return instance_type
def num_cpus(self):
    """Return the number of CPUs for a CPU cluster, or ``None`` if it is not specified.

    An explicitly set value wins. Otherwise, for an ``instance_type`` containing
    both ``":"`` and ``"CPU"``, the text after the last ``":"`` is returned.
    """
    explicit = self._num_cpus
    if explicit:
        return explicit

    instance_type = self.instance_type
    if instance_type and ":" in instance_type and "CPU" in instance_type:
        _, _, cpu_count = instance_type.rpartition(":")
        return cpu_count
    return None
async def a_up(self, capture_output: Union[bool, str] = True):
    """Up the cluster asynchronously in a separate process.

    Running in another process lets several launches be parallelized and keeps
    their logs separate.

    Args:
        capture_output: If True, suppress the output of the cluster creation
            process. If False, print the output normally. If a string, write
            the output to the file at that path.

    Returns:
        This cluster object.
    """
    event_loop = asyncio.get_running_loop()
    with ProcessPoolExecutor() as pool:
        await event_loop.run_in_executor(
            pool, up_cluster_helper, self, capture_output
        )
    return self
async def a_up_if_not(self, capture_output: Union[bool, str] = True):
    """Asynchronously up the cluster unless it is already up; returns this cluster object."""
    if self.is_up():
        return self
    await self.a_up(capture_output=capture_output)
    return self
def up(self, verbose: bool = True, force: bool = False):
    """Up the cluster.

    Does nothing when called from the cluster itself.

    Args:
        verbose (bool, optional): Whether to stream logs from Den when the cluster is being launched. Only
            relevant if launching via Den. (Default: `True`)
        force (bool, optional): Whether to launch the cluster even if one with the same configs already exists.
            Only relevant if launching via Den. (Default: `False`)

    Returns:
        This cluster object.

    Example:
        >>> rh.ondemand_cluster("rh-cpu").up()
    """
    if self.on_this_cluster():
        return self

    selected_launcher = self.launcher
    if selected_launcher == LauncherType.DEN:
        logger.info("Launching cluster with Den")
        DenLauncher.up(cluster=self, verbose=verbose, force=force)
    elif selected_launcher == LauncherType.LOCAL:
        logger.info("Provisioning cluster")
        LocalLauncher.up(cluster=self, verbose=verbose)

    return self
def keep_warm(self, mins: int = -1):
    """Set how long the cluster stays up after it becomes inactive.

    This assigns ``mins`` to the ``autostop_mins`` attribute.

    Args:
        mins (int): Minutes of inactivity after which the cluster is stopped.
            ``-1`` keeps the cluster warm indefinitely. (Default: `-1`)

    Returns:
        This cluster object.
    """
    setattr(self, "autostop_mins", mins)
    return self
[docs] def teardown(self, verbose: bool = True):
"""Teardown cluster.

Tears the cluster down through the Den launcher or, for any other launcher,
through the local launcher. For clusters with an RNS address, a TERMINATED
status is also posted to Den; failures of that update are logged as warnings
and never raised.

Args:
verbose (bool, optional): Whether to stream logs from Den when the cluster is being downed. Only relevant
when tearing down via Den. (Default: `True`)
Example:
>>> rh.ondemand_cluster("rh-cpu").teardown()
"""
if self.launcher == LauncherType.DEN:
logger.info("Tearing down cluster with Den.")
DenLauncher.teardown(cluster=self, verbose=verbose)
else:
logger.info("Tearing down cluster locally via Sky.")
LocalLauncher.teardown(cluster=self, verbose=verbose)
if self.rns_address is not None:
try:
# Update Den with the terminated status
status_data = {
"daemon_status": RunhouseDaemonStatus.TERMINATED,
# Lowercased name of this class's direct base class ("cluster" for
# OnDemandCluster itself).
# NOTE(review): a subclass of OnDemandCluster would report
# "ondemandcluster" here -- confirm that is intended.
"resource_type": self.__class__.__base__.__name__.lower(),
"data": {},
}
cluster_uri = rns_client.format_rns_address(self.rns_address)
status_resp = requests.post(
f"{rns_client.api_server_url}/resource/{cluster_uri}/cluster/status",
json=status_data,
headers=rns_client.request_headers(),
)
# Note: 404 means that the cluster is not saved in Den
if status_resp.status_code not in [200, 404]:
logger.warning(
"Failed to update Den with terminated cluster status"
)
# Best effort: any failure while updating Den is only logged.
except Exception as e:
logger.warning(e)
def teardown_and_delete(self, verbose: bool = True):
    """Tear the cluster down, then delete its saved configs.

    Args:
        verbose (bool, optional): Whether to stream logs from Den while the cluster is being downed. Only
            relevant when tearing down via Den. (Default: `True`)

    Example:
        >>> rh.ondemand_cluster("rh-cpu").teardown_and_delete()
    """
    # The configs are only deleted once teardown returns without raising.
    self.teardown(verbose)
    rns_client.delete_configs(resource=self)
@contextlib.contextmanager
def pause_autostop(self):
    """Context manager to temporarily pause autostop.

    On entry, autostop is disabled on the head node (set to ``-1``). On exit,
    the cluster's configured autostop value is restored, even when the body of
    the ``with`` block raises.

    Example:
        >>> cluster = rh.ondemand_cluster("rh-cpu")
        >>> with cluster.pause_autostop():
        >>>     cluster.run(["python train.py"])
    """
    self.run_bash_over_ssh(_cluster_set_autostop_command(-1), node=self.head_ip)
    try:
        yield
    finally:
        # Restore autostop on every exit path, so an exception inside the
        # `with` body cannot leave autostop disabled.
        self.run_bash_over_ssh(
            _cluster_set_autostop_command(self._autostop_mins), node=self.head_ip
        )
# ----------------- SSH Methods ----------------- #
@staticmethod
def cluster_ssh_key(path_to_file: Path):
    """Retrieve SSH key for the cluster.

    Args:
        path_to_file (Path): Path of the private key associated with the cluster.
            A leading ``~`` is expanded to the user's home directory.

    Returns:
        str: The contents of the key file.

    Raises:
        Exception: If no file exists at the given path.

    Example:
        >>> ssh_priv_key = rh.ondemand_cluster("rh-cpu").cluster_ssh_key("~/.ssh/id_rsa")
    """
    # Expand "~" so paths like the one in the example above resolve.
    key_path = Path(path_to_file).expanduser()
    try:
        # The context manager closes the file handle on every path.
        with open(key_path, "r") as key_file:
            return key_file.read()
    except FileNotFoundError as e:
        raise Exception(f"File with ssh key not found in: {path_to_file}") from e
def ssh(self, node: str = None):
    """SSH into the cluster.

    For clusters whose provider is ``"kubernetes"``, this opens an interactive
    bash session via ``kubectl exec`` in the first pod whose listing matches
    the cluster name; ``node`` is not used in that case. Otherwise an
    interactive SSH session is opened through the cluster's command runner.

    Args:
        node: Node to SSH into. If no node is specified, will SSH onto the head node.
            (Default: ``None``)

    Raises:
        Exception: If listing the Kubernetes pods fails or no pod matches.
        FileNotFoundError: If the SSH private key file does not exist.

    Example:
        >>> rh.ondemand_cluster("rh-cpu").ssh()
        >>> rh.ondemand_cluster("rh-cpu").ssh(node="3.89.174.234")
    """
    if self.provider == "kubernetes":
        namespace_flag = f"-n {self._namespace}" if self._namespace else ""
        command = f"kubectl get pods {namespace_flag} | grep {self.name}"

        try:
            output = subprocess.check_output(command, shell=True, text=True)
        except subprocess.CalledProcessError as e:
            raise Exception(f"Error: {e}") from e

        # Blank lines are dropped so empty output is reported as "no pods"
        # rather than failing later on an empty or unbound pod name.
        pod_lines = [line for line in output.splitlines() if line.strip()]
        if not pod_lines:
            raise Exception(f"No matching pods found for cluster {self.name}.")
        pod_name = pod_lines[0].split()[0]

        cmd = f"kubectl exec -it {pod_name} {namespace_flag} -- /bin/bash"
        subprocess.run(cmd, shell=True, check=True)
        return

    # If SSHing onto a specific node, which requires the default sky public key for verification
    from runhouse.resources.hardware.sky_command_runner import SshMode

    sky_key = Path(
        self.creds_values.get("ssh_private_key", self.DEFAULT_KEYFILE)
    ).expanduser()
    if not sky_key.exists():
        raise FileNotFoundError(f"Expected default sky key in path: {sky_key}")

    runner = self._command_runner(node=node)
    if self.docker_user:
        # With return_cmd=True the runner hands back the command for us to run
        # rather than running it; the bash rcfile deactivates conda on login.
        cmd = runner.run(
            cmd="bash --rcfile <(echo '. ~/.bashrc; conda deactivate')",
            ssh_mode=SshMode.INTERACTIVE,
            port_forward=None,
            return_cmd=True,
        )
        subprocess.run(cmd, shell=True)
    else:
        subprocess.run(
            runner._ssh_base_command(
                ssh_mode=SshMode.INTERACTIVE, port_forward=None
            )
        )
def _ping(self, timeout=5, retry=False):
    """Ping the cluster, optionally reloading connection info from Sky before a second attempt.

    Args:
        timeout: Timeout passed through to the base class ping.
        retry: When True and the first ping fails, the Sky status is refreshed
            and the ping is attempted once more.
    """
    first_attempt_ok = super()._ping(timeout=timeout, retry=False)
    if first_attempt_ok:
        return True
    if not retry:
        return False

    # The connection info may be stale, so reload it from Sky before retrying.
    self._update_from_sky_status(dryrun=False)
    return super()._ping(timeout=timeout, retry=False)