This example demonstrates how to use Airflow together with Runhouse to dispatch the work of training a basic Torch model to a remote GPU. The Airflow pipeline and all of its callables can be run from anywhere, including locally, but the pipeline will bring up a GPU cluster on AWS and send the training job there.
The Torch model is identical to our Torch training example, but placed within the context of an orchestrator to show how Runhouse allows for more flexible debugging and dispatch when used in conjunction with a traditional orchestrator. The model code uses the very popular MNIST dataset, which contains a large number of handwritten digits, and creates a neural network that accurately identifies the digit in an image.
Optionally, set up a virtual environment:
$ conda create -n demo-runhouse python=3.10
$ conda activate demo-runhouse
Install the required dependencies:
$ pip install "runhouse[aws]" torch torchvision airflow
We'll be launching an AWS EC2 instance via SkyPilot, so we need to make sure our AWS credentials are set up:
$ aws configure
$ sky check
Import the libraries we need - namely Airflow, Runhouse, and a few others from the standard library.
import logging
import os
import sys
from datetime import datetime, timedelta

import runhouse as rh
from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.operators.python import PythonOperator

# The SimpleTrainer class is the same as in
# https://www.run.house/examples/torch-vision-mnist-basic-model-train-test
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from TorchBasicExample import download_data, preprocess_data, SimpleTrainer

logger = logging.getLogger(__name__)
These callables will be executed in sequence by Airflow PythonOperators, so each task in the Airflow DAG stays minimal. The callables define both what each task does and where it runs - essentially, controlling the dispatch programmatically.
We bring up an on-demand cluster using Runhouse. Defining compute in code like this unlocks powerful usage patterns: all subsequent steps connect to this cluster by name, but you can just as easily bring up other clusters for other steps (see the sketch after the next function).
def bring_up_cluster_callable(**kwargs):
    logger.info("Connecting to remote cluster")
    cluster = rh.ondemand_cluster(
        name="a10g-cluster", instance_type="A10G:1", provider="aws"
    ).up_if_not()
    print(cluster.is_up())
    # cluster.save()  # Use if you have a Runhouse Den account to save and monitor the resource.
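For instance, if a later step only needed CPUs, nothing stops you from defining a second, cheaper cluster alongside the GPU one. Here is a minimal sketch; the "cpu-cluster" name and the "CPU:4+" instance type are illustrative assumptions, not part of the pipeline above:

def bring_up_cpu_cluster_callable(**kwargs):
    # Hypothetical second cluster for CPU-only steps; the name and
    # instance type here are assumptions for illustration.
    cpu_cluster = rh.ondemand_cluster(
        name="cpu-cluster", instance_type="CPU:4+", provider="aws"
    ).up_if_not()
    print(cpu_cluster.is_up())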
We send the download and preprocessing functions to the remote cluster and invoke them there, so the data lands on the remote machine. You can think of this as a data access or pre-processing step after which the data is ready for training.
def access_data_callable(**kwargs):
    logger.info("Step 2: Access data")
    env = rh.env(name="test_env", reqs=["torch", "torchvision"])
    cluster = rh.cluster(name="a10g-cluster").up_if_not()
    remote_download = rh.function(download_data).to(cluster, env=env)
    remote_preprocess = rh.function(preprocess_data).to(cluster, env=env)
    logger.info("Download and preprocess functions sent to remote")
    remote_download()
    remote_preprocess()
    logger.info("Data downloaded and preprocessed")
Next we instantiate the trainer on the remote machine, where we have a GPU, and invoke the training there. This is also a natural point to split the workflow if we want to run some tasks on GPU and some on CPU (see the sketch after the function below).
def train_model_callable(**kwargs):
    logger.info("Step 3: Train Model")
    cluster = rh.cluster(name="a10g-cluster").up_if_not()
    env = rh.env(name="test_env", reqs=["torch", "torchvision"])
    remote_torch_example = rh.module(SimpleTrainer).to(
        cluster, env=env, name="torch-basic-training"
    )
    model = remote_torch_example()

    batch_size = 64
    epochs = 5
    learning_rate = 0.01

    model.load_train("./data", batch_size)
    model.load_test("./data", batch_size)

    for epoch in range(epochs):
        model.train_model(learning_rate=learning_rate)
        model.test_model()
        model.save_model(
            bucket_name="my-simple-torch-model-example",
            s3_file_path=f"checkpoints/model_epoch_{epoch + 1}.pth",
        )
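To actually split GPU and CPU work, you could send the preprocessing function to a CPU cluster and the trainer to the GPU cluster. A minimal sketch, assuming a second cluster named "cpu-cluster" like the one sketched earlier exists, and that data is staged somewhere both clusters can reach (e.g. S3), since each cluster has its own local disk:

def split_gpu_cpu_callable(**kwargs):
    # Hypothetical split: CPU-bound preprocessing on one cluster, GPU-bound
    # training on another. Assumes "cpu-cluster" exists and that data is
    # shared via a common store such as S3.
    env = rh.env(name="test_env", reqs=["torch", "torchvision"])
    cpu_cluster = rh.cluster(name="cpu-cluster").up_if_not()
    gpu_cluster = rh.cluster(name="a10g-cluster").up_if_not()

    rh.function(preprocess_data).to(cpu_cluster, env=env)()
    trainer = rh.module(SimpleTrainer).to(gpu_cluster, env=env)()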
We programmatically bring down the cluster, though we could also leave it up and reuse it by name in later runs (a reuse sketch follows the function below).
def down_cluster(**kwargs):
    cluster = rh.cluster(name="a10g-cluster")
    cluster.teardown()
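Because the cluster is addressable by name, skipping this teardown lets any other script or DAG reconnect to the same machine. A minimal sketch of reusing it from a separate script, using only the calls that already appear above:

# From any other script or notebook, while the cluster is still up:
cluster = rh.cluster(name="a10g-cluster").up_if_not()
env = rh.env(name="test_env", reqs=["torch", "torchvision"])
remote_download = rh.function(download_data).to(cluster, env=env)
remote_download()  # Runs on the same remote machine the DAG used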
This is a simple DAG with multiple steps. Each step is an operator (one BashOperator and several PythonOperators) that runs a command or one of the callables defined above.
default_args = {
    "owner": "paul",
    "depends_on_past": False,
    "start_date": datetime(2024, 8, 6),
    "email_on_failure": False,
    "email_on_retry": False,
    "retries": 1,
    "retry_delay": timedelta(minutes=5),
}

dag = DAG(
    "pytorch_training_pipeline_example",
    default_args=default_args,
    description="A simple PyTorch training DAG with multiple steps",
    schedule=timedelta(days=1),
)

run_sky_status = BashOperator(
    task_id="run_sky_status",
    bash_command="sky status",
    dag=dag,
)

bring_up_cluster_task = PythonOperator(
    task_id="bring_up_cluster_task",
    python_callable=bring_up_cluster_callable,
    dag=dag,
)

access_data_task = PythonOperator(
    task_id="access_data_task",
    python_callable=access_data_callable,
    dag=dag,
)

train_model_task = PythonOperator(
    task_id="train_model_task",
    python_callable=train_model_callable,
    dag=dag,
)

down_cluster_task = PythonOperator(
    task_id="down_cluster_task",
    python_callable=down_cluster,
    dag=dag,
)
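Before scheduling the whole pipeline, you can smoke-test a single task with Airflow's built-in test command (assuming the DAG file sits in your configured DAGs folder):

$ airflow dags list
$ airflow tasks test pytorch_training_pipeline_example bring_up_cluster_task 2024-08-06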
You can see that this is an incredibly minimal amount of code in Airflow. The callables are invoked by the DAG, but you can just as well run them from a Python script, a notebook, or anywhere else - so you can iterate on the underlying classes and functions instantly, and by the time they run locally, they are ready for prime time in your DAG (see the sketch after the task wiring below).
(
    run_sky_status
    >> bring_up_cluster_task
    >> access_data_task
    >> train_model_task
    >> down_cluster_task
)
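Because the callables are plain Python functions, you can also exercise the whole sequence outside Airflow while iterating. A minimal sketch:

# Run the same steps locally, without the Airflow scheduler:
if __name__ == "__main__":
    bring_up_cluster_callable()
    access_data_callable()
    train_model_callable()
    down_cluster()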