You are viewing v0.0.12 version. Click here to see docs for the latest stable version.

Source code for runhouse.resources.hardware.on_demand_cluster

import contextlib
import logging
import subprocess
from pathlib import Path
from typing import Any, Dict

import sky
import yaml
from sky.backends import backend_utils, CloudVmRayBackend

from runhouse.globals import configs, rns_client

from .cluster import Cluster
from .utils import _current_cluster

logger = logging.getLogger(__name__)


[docs]class OnDemandCluster(Cluster): RESOURCE_TYPE = "cluster" RECONNECT_TIMEOUT = 5
[docs] def __init__( self, name, instance_type: str = None, num_instances: int = None, provider: str = None, dryrun=False, autostop_mins=None, use_spot=False, image_id=None, region=None, sky_state=None, **kwargs, # We have this here to ignore extra arguments when calling from from_config ): """ On-demand `SkyPilot <https://github.com/skypilot-org/skypilot/>`_ Cluster. .. note:: To build a cluster, please use the factory method :func:`cluster`. """ super().__init__(name=name, dryrun=dryrun) self.instance_type = instance_type self.num_instances = num_instances self.provider = provider or configs.get("default_provider") self.autostop_mins = ( autostop_mins if autostop_mins is not None else configs.get("default_autostop") ) self.use_spot = use_spot if use_spot is not None else configs.get("use_spot") self.image_id = image_id self.region = region self.address = None self.client = None self.sky_state = sky_state # Checks if state info is in local sky db, populates if so. status_dict = self.status(refresh=False) if status_dict: self._populate_connection_from_status_dict(status_dict) elif self.sky_state: self._save_sky_state() if not self.address and not dryrun: # Cluster status is set to INIT in the Sky DB right after starting, so we need to refresh once self._update_from_sky_status(dryrun=False)
@staticmethod def from_config(config: dict, dryrun=False): return OnDemandCluster(**config, dryrun=dryrun) @property def config_for_rns(self): config = super().config_for_rns # Also store the ssh keys for the cluster in RNS config.update( { "instance_type": self.instance_type, "num_instances": self.num_instances, "provider": self.provider, "autostop_mins": self.autostop_mins, "use_spot": self.use_spot, "image_id": self.image_id, "region": self.region, "sky_state": self._get_sky_state(), } ) return config def _get_sky_state(self): config = sky.global_user_state.get_cluster_from_name(self.name) if not config: return None config["status"] = config[ "status" ].name # ClusterStatus enum is not json serializable if config["handle"]: # with open(config["handle"].cluster_yaml, mode="r") as f: # config["ray_config"] = yaml.safe_load(f) config["public_key"] = self.ssh_creds()["ssh_private_key"] + ".pub" config["handle"] = { "cluster_name": config["handle"].cluster_name, # This is saved as an absolute path - convert it to relative "cluster_yaml": self.relative_yaml_path( yaml_path=config["handle"]._cluster_yaml ), "head_ip": config["handle"].head_ip or self.address, "stable_internal_external_ips": config[ "handle" ].stable_internal_external_ips, "launched_nodes": config["handle"].launched_nodes, "launched_resources": config[ "handle" ].launched_resources.to_yaml_config(), } config["handle"]["launched_resources"].pop("spot_recovery", None) config["ssh_creds"] = self.ssh_creds() return config def _copy_sky_yaml_from_cluster(self, abs_yaml_path: str): if not Path(abs_yaml_path).exists(): Path(abs_yaml_path).parent.mkdir(parents=True, exist_ok=True) self._rsync("~/.sky/sky_ray.yml", abs_yaml_path, up=False) # Save SSH info to the ~/.ssh/config ray_yaml = yaml.safe_load(open(abs_yaml_path, "r")) backend_utils.SSHConfigHelper.add_cluster( self.name, [self.address], ray_yaml["auth"] ) def _save_sky_state(self): if not self.sky_state: raise ValueError("No sky state to save") # if we're on this cluster, no need to save sky state current_cluster_name = _current_cluster("cluster_name") if self.sky_state.get("handle", {}).get("cluster_name") == current_cluster_name: return handle_info = self.sky_state.get("handle", {}) # If we already have the cluster in local sky db, # we don't need to save the state, just populate the connection info from the status if not sky.global_user_state.get_cluster_from_name(self.name): # Try running a command on the cluster before saving down the state into sky db self.address = handle_info.get("head_ip") self._ssh_creds = self.sky_state["ssh_creds"] try: self._ping(timeout=self.RECONNECT_TIMEOUT) except TimeoutError: self.address = None self._ssh_creds = None print( f"Timeout when trying to connect to cluster {self.name}, treating cluster as down." ) return resources = sky.Resources.from_yaml_config( handle_info["launched_resources"] ) # Need to convert to relative to find the yaml file in a new environment yaml_path = self.relative_yaml_path(handle_info.get("cluster_yaml")) handle = CloudVmRayBackend.ResourceHandle( cluster_name=self.name, cluster_yaml=str(Path(yaml_path).expanduser()), launched_nodes=handle_info["launched_nodes"], stable_internal_external_ips=handle_info.get( "stable_internal_external_ips" ) or [(handle_info["head_ip"], handle_info["head_ip"])], launched_resources=resources, ) sky.global_user_state.add_or_update_cluster( cluster_name=self.name, cluster_handle=handle, requested_resources=[resources], is_launch=True, ready=False, ) # Now try loading in the status from the sky DB status = self.status(refresh=False) abs_yaml_path = status["handle"].cluster_yaml try: if not Path(abs_yaml_path).exists(): # This is also a good way to check if the cluster is still up self._copy_sky_yaml_from_cluster(abs_yaml_path) else: # We still should check if the cluster is up, since the status/yaml file could be stale self._ping(timeout=self.RECONNECT_TIMEOUT) except Exception: # Refresh the cluster status before saving the ssh info so SkyPilot has a chance to wipe the .ssh/config if # the cluster went down self._update_from_sky_status(dryrun=self.dryrun) def __getstate__(self): """Make sure sky_state is loaded in before pickling.""" self.sky_state = self._get_sky_state() return super().__getstate__() @staticmethod def relative_yaml_path(yaml_path): if Path(yaml_path).is_absolute(): yaml_path = "~/.sky/generated/" + Path(yaml_path).name return yaml_path # ----------------- Launch/Lifecycle Methods -----------------
[docs] def is_up(self) -> bool: """Whether the cluster is up. Example: >>> rh.ondemand_cluster("rh-cpu").is_up() """ self._update_from_sky_status(dryrun=False) return self.address is not None
[docs] def status(self, refresh: bool = True): """ Get status of Sky cluster. Return dict looks like: .. code-block:: {'name': 'sky-cpunode-donny', 'launched_at': 1662317201, 'handle': ResourceHandle( cluster_name=sky-cpunode-donny, head_ip=54.211.97.164, cluster_yaml=/Users/donny/.sky/generated/sky-cpunode-donny.yml, launched_resources=1x AWS(m6i.2xlarge), tpu_create_script=None, tpu_delete_script=None), 'last_use': 'sky cpunode', 'status': <ClusterStatus.UP: 'UP'>, 'autostop': -1, 'metadata': {}} .. note:: For more information see SkyPilot's :code:`ResourceHandle` `class <https://github.com/skypilot-org/skypilot/blob/0c2b291b03abe486b521b40a3069195e56b62324/sky/backends/cloud_vm_ray_backend.py#L1457>`_. Example: >>> status = rh.ondemand_cluster("rh-cpu").status() """ # noqa # return backend_utils._refresh_cluster_record( # self.name, force_refresh=refresh, acquire_per_cluster_status_lock=False # ) if not sky.global_user_state.get_cluster_from_name(self.name): return None state = sky.status(cluster_names=[self.name], refresh=refresh) # We still need to check if the cluster present in case the cluster went down and was removed from the DB if len(state) == 0: return None return state[0]
def _populate_connection_from_status_dict(self, cluster_dict: Dict[str, Any]): if cluster_dict and cluster_dict["status"].name in ["UP", "INIT"]: self.address = cluster_dict["handle"].head_ip yaml_path = cluster_dict["handle"].cluster_yaml if Path(yaml_path).exists(): self._ssh_creds = backend_utils.ssh_credential_from_yaml(yaml_path) else: self.address = None self._ssh_creds = None def _update_from_sky_status(self, dryrun: bool = False): # Try to get the cluster status from SkyDB cluster_dict = self.status(refresh=not dryrun) self._populate_connection_from_status_dict(cluster_dict)
[docs] def up(self): """Up the cluster. Example: >>> rh.ondemand_cluster("rh-cpu").up() """ if self.on_this_cluster(): return self if self.provider in ["aws", "gcp", "azure", "lambda", "cheapest"]: task = sky.Task( num_nodes=self.num_instances if self.instance_type and ":" not in self.instance_type else None, # docker_image=image, # Zongheng: this is experimental, don't use it # envs=None, ) cloud_provider = ( sky.clouds.CLOUD_REGISTRY.from_str(self.provider) if self.provider != "cheapest" else None ) task.set_resources( sky.Resources( cloud=cloud_provider, instance_type=self.instance_type if self.instance_type and ":" not in self.instance_type and "CPU" not in self.instance_type else None, accelerators=self.instance_type if self.instance_type and ":" in self.instance_type and "CPU" not in self.instance_type else None, cpus=self.instance_type.rsplit(":", 1)[1] if self.instance_type and ":" in self.instance_type and "CPU" in self.instance_type else None, region=self.region, image_id=self.image_id, use_spot=self.use_spot, ) ) if Path("~/.rh").expanduser().exists(): task.set_file_mounts( { "~/.rh": "~/.rh", } ) # If we choose to reduce collisions of cluster names: # cluster_name = self.rns_address.strip('~/').replace("/", "-") sky.launch( task, cluster_name=self.name, idle_minutes_to_autostop=self.autostop_mins, down=True, ) elif self.provider == "k8s": raise NotImplementedError("Kubernetes Cluster provider not yet supported") else: raise ValueError(f"Cluster provider {self.provider} not supported.") self._update_from_sky_status() self.restart_server(restart_ray=True) return self
[docs] def keep_warm(self, autostop_mins: int = -1): """Keep the cluster warm for given number of minutes after inactivity. Args: autostop_mins (int): Amount of time (in min) to keep the cluster warm after inactivity. If set to -1, keep cluster warm indefinitely. (Default: `-1`) """ sky.autostop(self.name, autostop_mins, down=True) self.autostop_mins = autostop_mins return self
[docs] def teardown(self): """Teardown cluster. Example: >>> rh.ondemand_cluster("rh-cpu").teardown() """ # Stream logs sky.down(self.name) self.address = None
[docs] def teardown_and_delete(self): """Teardown cluster and delete it from configs. Example: >>> rh.ondemand_cluster("rh-cpu").teardown_and_delete() """ self.teardown() rns_client.delete_configs()
[docs] @contextlib.contextmanager def pause_autostop(self): """Context manager to temporarily pause autostop. Example: >>> with rh.ondemand_cluster.pause_autostop(): >>> rh.ondemand_cluster.run(["python train.py"]) """ sky.autostop(self.name, idle_minutes=-1) yield sky.autostop(self.name, idle_minutes=self.autostop_mins)
# ----------------- SSH Methods ----------------- #
[docs] @staticmethod def cluster_ssh_key(path_to_file): """Retrieve SSH key for the cluster. Example: >>> ssh_priv_key = rh.ondemand_cluster("rh-cpu").cluster_ssh_key("~/.ssh/id_rsa") """ try: f = open(path_to_file, "r") private_key = f.read() return private_key except FileNotFoundError: raise Exception(f"File with ssh key not found in: {path_to_file}")
[docs] def ssh_creds(self): """Retrieve SSH creds for the cluster. Example: >>> credentials = rh.ondemand_cluster("rh-cpu").ssh_creds() """ if self._ssh_creds: return self._ssh_creds if not self.status(refresh=False) and self.sky_state: # If this cluster was serialized and sent over the wire, it will have sky_state (we make sure of that # in __getstate__) but no yaml, and we need to save down the sky data to the sky db and local yaml self._save_sky_state() else: # To avoid calling this twice (once in save_sky_data) self._update_from_sky_status(dryrun=True) return self._ssh_creds
[docs] def ssh(self): """SSH into the cluster. Example: >>> rh.ondemand_cluster("rh-cpu").ssh() """ subprocess.run(["ssh", f"{self.name}"])