This example demonstrates how to use Argo to orchestrate the training of a basic PyTorch model, dispatching the training class to a GPU-enabled AWS instance to do the actual training.
We use the popular MNIST dataset of handwritten digits (60,000 training images and 10,000 test images) and train a neural network that accurately classifies the digit in each image.
For this example, we will need AWS credentials and a Runhouse API key. We store them in a Kubernetes secret, named my-secret for simplicity, and read it from the pipeline:
```shell
kubectl create secret generic my-secret \
  --from-literal=AWS_ACCESS_KEY_ID=<your-access-key-id> \
  --from-literal=AWS_SECRET_ACCESS_KEY=<your-secret-access-key> \
  --from-literal=RUNHOUSE_API_KEY=<your-runhouse-api-key>
```
We'll launch elastic compute on AWS from within the first step and reuse the same compute across subsequent steps, but you can use any compute resource. Reusing the cluster brings significant benefits, such as statefulness and minimized I/O overhead. As the multi-cloud example shows, you can even run steps on different clusters - for instance, run CPU pre-processing on the cluster that hosts Argo while offloading GPU work to an elastic instance.
The code actually executed by the pipeline is extremely lean: each step is a function run on a remote cluster. The functions are defined in the TorchBasicExample module and dispatched to the remote cluster with Runhouse. For simplicity, we have inlined the code in the pipeline spec, but you would likely define these Runhouse calls in a separate file or container.
```yaml
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
  generateName: pytorch-training-pipeline-
spec:
  entrypoint: pytorch-training-pipeline
  templates:
    - name: pytorch-training-pipeline
      steps:
        - - name: bring-up-cluster
            template: bring-up-cluster-task
        - - name: access-data
            template: access-data-task
        - - name: train-model
            template: train-model-task
        - - name: down-cluster
            template: down-cluster-task

    # First, we can bring up an on-demand cluster using Runhouse. You can
    # access powerful usage patterns by defining compute in code. All
    # subsequent steps connect to this cluster by name, but you can also
    # bring up other clusters for other steps.
    - name: bring-up-cluster-task
      script:
        image: pypypypy/my-pipeline-image:latest
        command: [python]
        env:
          - name: AWS_ACCESS_KEY_ID
            valueFrom:
              secretKeyRef:
                name: my-secret
                key: AWS_ACCESS_KEY_ID
          - name: AWS_SECRET_ACCESS_KEY
            valueFrom:
              secretKeyRef:
                name: my-secret
                key: AWS_SECRET_ACCESS_KEY
          - name: RUNHOUSE_API_KEY
            valueFrom:
              secretKeyRef:
                name: my-secret
                key: RUNHOUSE_API_KEY
        source: |
          # We show the code here for simpler illustration of the workflow, but
          # Runhouse does not require any special setup; use scripts or containers.
          import os
          import subprocess

          import runhouse as rh

          # First we configure the environment to set up Runhouse and AWS
          # credentials. We only need the AWS credentials in this first step,
          # since the cluster is saved to Runhouse and we reuse the resource.
          rh.login(token=os.getenv("RUNHOUSE_API_KEY"), interactive=False)
          subprocess.run(
              ["aws", "configure", "set", "aws_access_key_id", os.getenv("AWS_ACCESS_KEY_ID")],
              check=True,
          )
          subprocess.run(
              ["aws", "configure", "set", "aws_secret_access_key", os.getenv("AWS_SECRET_ACCESS_KEY")],
              check=True,
          )
          print(os.getcwd())

          # Now we bring up the cluster and save it to Runhouse to reuse in
          # subsequent steps. This allows for reuse of the same compute and much
          # better statefulness across multiple Argo steps.
          cluster = rh.ondemand_cluster(
              name="rh-a10g-torchtrain",
              instance_type="A10G:1",
              provider="aws",
              autostop_mins=90,
          ).up_if_not()
          print(cluster.is_up())
          cluster.save()

    # This step accesses and lightly preprocesses the dataset. The MNIST example
    # is trivial, but it is worth calling out that we do this preprocessing on
    # the same compute we will use later for training, so we do not need to
    # re-access or re-download the data.
    - name: access-data-task
      script:
        image: pypypypy/my-pipeline-image:latest
        command: [python]
        env:
          - name: RUNHOUSE_API_KEY
            valueFrom:
              secretKeyRef:
                name: my-secret
                key: RUNHOUSE_API_KEY
        source: |
          import os
          import sys

          import runhouse as rh

          sys.path.append(os.path.expanduser("~/training"))
          from TorchBasicExample import download_data, preprocess_data

          rh.login(token=os.getenv("RUNHOUSE_API_KEY"), interactive=False)

          env = rh.env(name="test_env", reqs=["torch", "torchvision"])
          # The /paul/ prefix is here because the cluster was saved to Runhouse
          # under that account.
          cluster = rh.cluster(name="/paul/rh-a10g-torchtrain")
          cluster.is_up()

          remote_download = rh.function(download_data).to(cluster, env=env)
          remote_preprocess = rh.function(preprocess_data).to(cluster, env=env)
          remote_download()
          remote_preprocess("./data")

    # Now we run the training. In this step, we dispatch the training to the
    # remote cluster. The model is trained there, and the model checkpoints are
    # saved to an S3 bucket.
    - name: train-model-task
      script:
        image: pypypypy/my-pipeline-image:latest
        command: [python]
        env:
          - name: RUNHOUSE_API_KEY
            valueFrom:
              secretKeyRef:
                name: my-secret
                key: RUNHOUSE_API_KEY
        source: |
          import os
          import sys

          import runhouse as rh

          sys.path.append(os.path.expanduser("~/training"))
          from TorchBasicExample import SimpleTrainer

          rh.login(token=os.getenv("RUNHOUSE_API_KEY"), interactive=False)

          cluster = rh.cluster(name="/paul/rh-a10g-torchtrain").up_if_not()
          env = rh.env(name="test_env", reqs=["torch", "torchvision"])
          remote_torch_example = rh.module(SimpleTrainer).to(
              cluster, env=env, name="torch-basic-training"
          )
          model = remote_torch_example()

          batch_size = 64
          epochs = 5
          learning_rate = 0.01

          model.load_train("./data", batch_size)
          model.load_test("./data", batch_size)

          for epoch in range(epochs):
              model.train_model(learning_rate=learning_rate)
              model.test_model()
              model.save_model(
                  bucket_name="my-simple-torch-model-example",
                  s3_file_path=f"checkpoints/model_epoch_{epoch + 1}.pth",
              )

    # Finally, we can down the cluster after the training is done.
    - name: down-cluster-task
      script:
        image: pypypypy/my-pipeline-image:latest
        command: [python]
        env:
          - name: RUNHOUSE_API_KEY
            valueFrom:
              secretKeyRef:
                name: my-secret
                key: RUNHOUSE_API_KEY
        source: |
          import os

          import runhouse as rh

          rh.login(token=os.getenv("RUNHOUSE_API_KEY"), interactive=False)
          cluster = rh.cluster(name="/paul/rh-a10g-torchtrain")
          cluster.teardown()
```
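One caveat worth noting: with sequential steps, the down-cluster step only runs if every prior step succeeds, so a failed training run can leave the A10G instance running until the 90-minute autostop fires. Argo's exit handlers can guarantee teardown regardless of outcome; a minimal sketch, reusing the teardown template name from the workflow above:

```yaml
# Sketch: register the teardown template as an exit handler so Argo runs it
# whether the workflow succeeds, fails, or is stopped.
spec:
  entrypoint: pytorch-training-pipeline
  onExit: down-cluster-task
```

With this in place, the explicit down-cluster step in the step list becomes optional, and the cluster's `autostop_mins` remains a backstop in case the exit handler itself cannot run.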