This example defines a retrieval augmented generation (RAG) app that references text from websites to enrich the response from an LLM. Depending on the URLs you use to populate the vector database, you'll be able to answer questions more intelligently with relevant context.
You could, for example, pass in the URL of a new startup and then ask the app to answer questions about that startup using information on their site. Or pass in several specialized gardening websites and ask nuanced questions about horticulture.
Deploy a FastAPI app that is able to create and store embeddings from text on public website URLs, and generate answers to questions using related context from stored websites and an open source LLM.
Note: Some of the steps in this example could also be accomplished with platforms like OpenAI and tools such as LangChain, but we break out the components explicitly to fully illustrate each step and make the example easily adaptable to other use cases. Swap out components as you see fit!
Runhouse allows you to turn complex operations such as preprocessing and inference into independent services. Servicifying with Runhouse enables:
We chose FastAPI as our platform because of its popularity and simplicity. However, we could easily use any other Python-based platform with Runhouse. Streamlit, Flask, <your_favorite>, we got you!
To ensure that Runhouse is able to manage deploying services to your cloud provider (AWS in this case), you may need to follow initial setup steps. Please visit the AWS section of our Installation Guide.
Additionally, we'll be downloading the Llama 3 model from Hugging Face, so we need to set up our Hugging Face token:
$ export HF_TOKEN=<your huggingface token>
Make sure to sign the waiver on the Hugging Face model page so that you can access it.
First, we'll import necessary packages and initialize variables used in the application. The URLEmbedder and LlamaModel classes that will be sent to Runhouse are available in the app/modules folder in this source code.
from contextlib import asynccontextmanager
from typing import Dict, List

import lancedb
import runhouse as rh
from fastapi import Body, FastAPI, HTTPException

from app.modules.embedding import Item, URLEmbedder
from app.modules.llm import LlamaModel

EMBEDDER, TABLE, LLM = None, None, None
DEBUG = True  # In DEBUG mode we will always override Runhouse modules
Define configuration options for our remote cluster. If you prefer to use a different cloud provider, be sure to install the appropriate version of Runhouse, e.g. runhouse[aws] or runhouse[gcp].
CLUSTER_NAME = "rh-xa10g"  # Allows the cluster to be reused
GPUS = "A10G:1"  # A10G GPU to handle LLM inference
CLOUD_PROVIDER = "aws"  # Alternatively "gcp", "azure", or "cheapest"
Template to be used in the LLM generation phase of the RAG app:
PROMPT_TEMPLATE = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
Always say "thanks for asking!" at the end of the answer.

{context}

Question: {question}

Helpful Answer:
"""
We'll use the lifespan argument of the FastAPI app to initialize our embedding service, vector database, and LLM service when the application is first created.
@asynccontextmanager
async def lifespan(app):
    global EMBEDDER, TABLE, LLM

    EMBEDDER = load_embedder()
    TABLE = load_table()
    LLM = load_llm()

    yield
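To make the ordering concrete, here is a minimal sketch of how an async context manager like the one above behaves. It uses only the standard library and drives the context manager by hand; the `EVENTS` list and the placeholder `app=None` are assumptions for illustration, and FastAPI itself is not involved.

```python
import asyncio
from contextlib import asynccontextmanager

EVENTS = []  # records the order of operations, for illustration only


@asynccontextmanager
async def lifespan(app):
    # Code before `yield` runs once, before any request is served.
    EVENTS.append("startup: load services")
    yield
    # Code after `yield` runs once, when the application shuts down.
    EVENTS.append("shutdown: release resources")


async def main():
    # FastAPI enters this context manager itself; here we enter it
    # manually with a placeholder `app` to show the ordering.
    async with lifespan(app=None):
        EVENTS.append("serving requests")


asyncio.run(main())
print(EVENTS)
```

In the app above nothing follows the `yield`, so there is no shutdown logic; the services are only loaded at startup.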
This method, run during initialization, will provision a remote machine (A10G on AWS in this case) and deploy our URLEmbedder to that machine.
The Python packages required by the embedding service (langchain etc.) are defined on the Runhouse image so that they can be properly installed on the cluster. Also note that they are imported inside class methods. Those packages do not need to be installed locally.
def load_embedder():
    """Launch an A10G and send the embedding service to it."""
    img = rh.Image("embedder_img").install_packages(
        [
            "langchain",
            "langchain-community",
            "langchain_text_splitters",
            "langchainhub",
            "bs4",
            "sentence_transformers",
            "torch",
        ],
    )
    cluster = rh.cluster(
        name=CLUSTER_NAME, gpus=GPUS, provider=CLOUD_PROVIDER, image=img
    ).up_if_not()

    module_name = "url_embedder"
    remote_url_embedder = cluster.get(module_name, default=None, remote=True)
    if DEBUG or remote_url_embedder is None:
        RemoteEmbedder = rh.module(URLEmbedder).to(system=cluster, name="URLEmbedder")
        remote_url_embedder = RemoteEmbedder(
            model_name_or_path="BAAI/bge-large-en-v1.5", device="cuda", name=module_name
        )
    return remote_url_embedder
We'll be using open source LanceDB to create an embedded database to store the URL embeddings and perform vector search for the retrieval phase. You could alternatively try Chroma, Pinecone, Weaviate, or even MongoDB.
def load_table():
    # Initialize LanceDB database directly on the FastAPI app's machine
    db = lancedb.connect("/tmp/db")
    return db.create_table("rag-table", schema=Item.to_arrow_schema(), exist_ok=True)
Deploy an open LLM, Llama 3 in this case, to a GPU on the cloud provider of your choice. We will use vLLM to serve the model due to its high performance and throughput, but there are many other options such as Hugging Face Transformers and TGI.
Here we leverage the same A10G cluster we used for the embedding service, but you could also spin up a new remote machine specifically for the LLM service. Alternatively, use a proprietary model like ChatGPT or Claude.
def load_llm():
    """Use the existing A10G cluster to run an LLM inference service"""
    # Specifying the same name will reuse our embedding service cluster
    img = (
        rh.Image("llama3_inference")
        .install_packages(["torch", "vllm==0.5.4"])
        .sync_secrets(["huggingface"])
    )
    cluster = rh.cluster(
        CLUSTER_NAME, gpus=GPUS, provider=CLOUD_PROVIDER, image=img
    ).up_if_not()

    module_name = "llama_model"
    # First check for an instance of the LlamaModel stored on the cluster
    remote_llm = cluster.get(module_name, default=None, remote=True)
    if DEBUG or remote_llm is None:
        # If not found (or debugging) sync up the model and create a fresh instance
        RemoteLlama = rh.module(LlamaModel).to(system=cluster, name="LlamaModel")
        remote_llm = RemoteLlama(name=module_name)
    return remote_llm
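Both loader functions follow the same "reuse the stored instance unless it is missing or we are debugging" pattern. The sketch below isolates that control flow in plain Python. It does not call Runhouse: `REMOTE_STORE`, `get_or_create`, and `make_service` are hypothetical stand-ins, with a dictionary lookup playing the role of `cluster.get(module_name, default=None, remote=True)`.

```python
# Hypothetical stand-in for objects stored on a cluster, keyed by name.
REMOTE_STORE = {}


def get_or_create(name, factory, debug=False):
    """Return the stored object for `name`, rebuilding it if missing or if debug is set."""
    existing = REMOTE_STORE.get(name)  # analogous to a lookup with default=None
    if debug or existing is None:
        existing = factory()  # analogous to sending the class and instantiating it
        REMOTE_STORE[name] = existing
    return existing


calls = []


def make_service():
    """Hypothetical factory that records how often a fresh instance is built."""
    calls.append("built")
    return object()


first = get_or_create("llama_model", make_service)
second = get_or_create("llama_model", make_service)  # reused, no rebuild
third = get_or_create("llama_model", make_service, debug=True)  # forced rebuild
print(len(calls))
```

With the debug flag off, the second call reuses the first instance, which is the behavior you would likely want outside of development.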
Before defining endpoints, we'll initialize the application and set the lifespan events defined above. This will load in the various services we've defined on start-up.
app = FastAPI(lifespan=lifespan)
Add an endpoint to check on our app health. This is a minimal example intended to only show if the application is up and running or down.
@app.get("/health")
def health_check():
    return {"status": "healthy"}
To illustrate the flexibility of FastAPI, we're allowing embeddings to be added to your database via a POST endpoint. This method will use the embedder service to create database entries with the source, content, and vector embeddings for chunks of text from a provided list of URLs.
@app.post("/embeddings")
async def generate_embeddings(paths: List[str] = Body([]), kwargs: Dict = Body({})):
    """Generate embeddings for the URL and write to DB."""
    try:
        items = await EMBEDDER.embed_docs(
            paths,
            normalize_embeddings=True,
            run_async=True,
            stream_logs=False,
            **kwargs,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to embed URLs: {str(e)}")

    items = [Item(**item) for item in items]
    TABLE.add(items)
    return {"status": "success"}
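The URLEmbedder class lives in app/modules and is not shown here, so the following is only a rough sketch of what "chunks of text" tagged with their source might look like. It assumes fixed-size character windows with a small overlap; the real embedder may split text differently (for example with a LangChain text splitter). The function name `chunk_text` and its parameters are hypothetical, and the vector field is left out.

```python
def chunk_text(url: str, text: str, size: int = 40, overlap: int = 10) -> list:
    """Split text into overlapping character windows, tagging each with its source URL."""
    if overlap >= size:
        raise ValueError("overlap must be smaller than size")
    step = size - overlap
    records = []
    for start in range(0, len(text), step):
        piece = text[start : start + size]
        # Field names mirror the attributes used elsewhere in this example
        # (url, page_content); the embedding vector would be added separately.
        records.append({"url": url, "page_content": piece})
        if start + size >= len(text):
            break  # the last window already reaches the end of the text
    return records


sample = "".join(chr(97 + i % 26) for i in range(100))  # 100 made-up characters
records = chunk_text("https://example.com", sample)
print(len(records), [len(r["page_content"]) for r in records])
```

The overlap means the end of one chunk repeats at the start of the next, which helps keep sentences that straddle a boundary retrievable.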
Now that we've defined our services and created an endpoint to populate data for retrieval, the remaining components of the application will focus on the generative phases of the RAG app.
In the retrieval phase we'll first use the Embedder service to create an embedding from input text to search our LanceDB vector database with. LanceDB is optimized for vector searches in this manner.
async def retrieve_documents(text: str, limit: int) -> List[Item]:
    """Retrieve documents from vector DB related to input text"""
    try:
        # Encode the input text into a vector
        vector = await EMBEDDER.encode_text(
            text, normalize_embeddings=True, run_async=True, stream_logs=False
        )
        # Search LanceDB for nearest neighbors to the vector embed
        results = TABLE.search(vector).limit(limit).to_pydantic(Item)
        return results

    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to retrieve documents: {str(e)}"
        )
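For intuition about what a nearest-neighbor search does, here is a small brute-force sketch in plain Python. It ranks rows by cosine similarity to a query vector and keeps the top `limit`. This is an illustration only: LanceDB's actual indexing and distance metric are not reproduced here, and the two-dimensional vectors, URLs, and function names are made up.

```python
import math


def normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def top_k(query, rows, limit):
    """Return the `limit` rows most similar to `query`.

    After normalization, the dot product of two vectors equals their cosine similarity.
    """
    q = normalize(query)
    scored = []
    for row in rows:
        score = sum(a * b for a, b in zip(q, normalize(row["vector"])))
        scored.append((score, row))
    scored.sort(key=lambda pair: pair[0], reverse=True)  # most similar first
    return [row for _, row in scored[:limit]]


# Toy 2-D "embeddings"; real embeddings have hundreds of dimensions or more.
rows = [
    {"url": "https://example.com/a", "vector": [1.0, 0.0]},
    {"url": "https://example.com/b", "vector": [0.0, 1.0]},
    {"url": "https://example.com/c", "vector": [0.9, 0.1]},
]
results = top_k([1.0, 0.05], rows, limit=2)
print([r["url"] for r in results])
```

A vector database avoids scanning every row like this, but the ranking idea is the same.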
To leverage the documents retrieved from the previous step, we'll format a prompt that provides text from related documents as "context" for the LLM. This allows a general purpose LLM (like Llama) to provide more specific responses to a particular question.
async def format_prompt(text: str, docs: List[Item]) -> str:
    """Format a prompt that includes retrieved document text as context"""
    context = "\n".join([doc.page_content for doc in docs])
    prompt = PROMPT_TEMPLATE.format(question=text, context=context)
    return prompt
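To see the substitution step in isolation, here is a minimal standalone sketch. It uses a shortened, hypothetical template and made-up context strings rather than the app's PROMPT_TEMPLATE and real retrieved documents; `build_prompt` is an illustrative name.

```python
# Shortened, hypothetical template for illustration only.
TEMPLATE = """Use the following context to answer the question.

{context}

Question: {question}

Helpful Answer:
"""


def build_prompt(question: str, chunks: list) -> str:
    """Join retrieved text chunks and substitute them into the template."""
    context = "\n".join(chunks)
    return TEMPLATE.format(context=context, question=question)


# Made-up chunks standing in for text retrieved from the vector database.
prompt = build_prompt(
    "Is this park bear country?",
    ["First retrieved chunk of text.", "Second retrieved chunk of text."],
)
print(prompt)
```

The LLM only ever sees the final string, so the quality of the answer depends heavily on which chunks end up in the context slot.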
Using the methods above, this endpoint will run inference on our LLM to generate a response to a question. The results are enhanced by first retrieving related documents from the source URLs fed into the POST endpoint. Content from the fetched documents is then formatted into the text prompt sent to our self-hosted LLM. We'll be using a generic prompt template to illustrate how many "chat" tools work behind the scenes.
@app.get("/generate")
async def generate_response(text: str, limit: int = 4):
    """Generate a response to a question using an LLM with context from our database"""
    if not text:
        return {"error": "Question text is missing"}

    try:
        # Retrieve related documents from vector DB
        documents = await retrieve_documents(text, limit)

        # List of sources from retrieved documents
        sources = set([doc.url for doc in documents])

        # Create a prompt using the documents and search text
        prompt = await format_prompt(text, documents)

        # Send prompt with optional sampling parameters for vLLM
        # More info: https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py
        response = await LLM.generate(
            prompt=prompt, temperature=0.8, top_p=0.95, max_tokens=100
        )

        return {"question": text, "response": response, "sources": sources}
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to generate response: {str(e)}"
        )
Use the following command to run the app from your terminal:
$ fastapi run app/main.py
After a few minutes, you can navigate to http://127.0.0.1:8000/health to check that your application is running. This may take a while due to initialization logic in the lifespan.
You'll see something like:
{ "status": "healthy" }
To debug the application, you may prefer running fastapi dev. This will trigger automatic re-deployments from any changes to your code. Be sure to set DEBUG to True to override instances of the embedding and LLM services with updated versions.
To populate the LanceDB database with vector embeddings for use in the RAG app, you can send an HTTP request to the /embeddings POST endpoint. Let's say you have a question about bears. You could send a cURL command with a list of URLs including essential bear information:
curl --header "Content-Type: application/json" \
  --request POST \
  --data '{"paths":["https://www.nps.gov/yell/planyourvisit/safety.htm", "https://en.wikipedia.org/wiki/Adventures_of_the_Gummi_Bears"]}' \
  http://127.0.0.1:8000/embeddings
Alternatively, we recommend a tool like Postman to test HTTP APIs.
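If you would rather stay in Python, the same POST request can be sketched with the standard library. The helper name `build_embeddings_request` is hypothetical, and the base URL assumes the app is listening on port 8000 as in the cURL command above. The request is only built here; the commented lines show how it could be sent to a running app.

```python
import json
import urllib.request


def build_embeddings_request(paths, base_url="http://127.0.0.1:8000"):
    """Build (but do not send) a JSON POST request for the /embeddings endpoint."""
    body = json.dumps({"paths": paths}).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/embeddings",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_embeddings_request(
    [
        "https://www.nps.gov/yell/planyourvisit/safety.htm",
        "https://en.wikipedia.org/wiki/Adventures_of_the_Gummi_Bears",
    ]
)
print(request.get_method(), request.full_url)

# To send it against a running app:
# with urllib.request.urlopen(request) as resp:
#     print(json.load(resp))
```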
Open your browser and send a prompt to your locally running RAG app by appending your question to the URL as a query param, e.g. ?text=Does%20yellowstone%20have%20gummi%20bears%3F
"http://127.0.0.1:8000/generate?text=Does%20yellowstone%20have%20gummi%20bears%3F"
The LlamaModel will need to load on the initial call and may take a few minutes to generate a response. Subsequent calls will generally take less than a second.
Example output:
{
  "question": "Does yellowstone have gummi bears?",
  "response": [
    " No, Yellowstone is bear country, not gummi bear country. Thanks for asking! "
  ],
  "sources": [
    "https://www.nps.gov/yell/planyourvisit/safety.htm",
    "https://en.wikipedia.org/wiki/Adventures_of_the_Gummi_Bears"
  ]
}
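The percent-encoded query string does not need to be written by hand. As a small sketch, the standard library can build the /generate URL from a plain question. The helper name `build_generate_url` is hypothetical, the base URL assumes port 8000 as in the cURL command, and `limit` corresponds to the optional query parameter of the endpoint.

```python
from urllib.parse import quote, urlencode


def build_generate_url(question, limit=None, base_url="http://127.0.0.1:8000"):
    """Build a /generate URL with the question percent-encoded as a query param."""
    params = {"text": question}
    if limit is not None:
        params["limit"] = limit  # optional number of documents to retrieve
    # quote_via=quote encodes spaces as %20 instead of "+"
    return f"{base_url}/generate?{urlencode(params, quote_via=quote)}"


url = build_generate_url("Does yellowstone have gummi bears?")
print(url)
```

Passing `limit=2`, for example, would append `&limit=2` so that only two documents are retrieved as context.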
There are many methods to deploy a FastAPI application to a production environment. With some modifications to the logic of the app, setting DEBUG to False, and deploying to a public IP, this example could easily serve as the backend to a RAG app.
We won't go into depth on a specific method, but here are a few things to consider:
This example uses sync_secrets to grant permissions to use Llama 3. Make sure you handle this on your server as well.
sky commands make it easy to manage remote clusters locally, but you may also want to monitor your cloud provider to avoid unused GPUs running up a bill.
If you're running into any problems using Runhouse in production, please reach out to our team at team@run.house. We'd be happy to set up a time to help you debug live. Additionally, you can chat with us directly on Discord.