Ever waited too long for a model to return predictions? We have all been there. Machine learning models, especially the large, complex ones, can be painfully slow to serve in real time. Users, on the other hand, expect instant feedback. That’s where latency becomes a real problem. Technically speaking, one of the biggest problems is redundant computation when the same input triggers the same slow process repeatedly. In this blog, I’ll show you how to fix that. We will build a FastAPI-based ML service and integrate Redis caching to return repeated predictions in milliseconds.
FastAPI is a modern, high-performance web framework for building APIs with Python. It uses Python's type hints for data validation and automatic generation of interactive API documentation using Swagger UI and ReDoc. Built on top of Starlette and Pydantic, FastAPI supports asynchronous programming, making it comparable in performance to Node.js and Go. Its design facilitates rapid development of robust, production-ready APIs, making it an excellent choice for deploying machine learning models as scalable RESTful services.
Redis (Remote Dictionary Server) is an open-source, in-memory data structure store that functions as a database, cache, and message broker. By storing data in memory, Redis offers ultra-low latency for read and write operations, making it ideal for caching frequent or computationally intensive tasks like machine learning model predictions. It supports various data structures, including strings, lists, sets, and hashes, and provides features like key expiration (TTL) for efficient cache management.
Integrating FastAPI with Redis creates a system that is both responsive and efficient. FastAPI serves as a swift and reliable interface for handling API requests, while Redis acts as a caching layer to store the results of previous computations. When the same input is received again, the result can be retrieved instantly from Redis, bypassing the need for recomputation. This approach reduces latency, lowers computational load, and enhances the scalability of your application. In distributed environments, Redis serves as a centralised cache accessible by multiple FastAPI instances, making it an excellent fit for production-grade machine learning deployments.
Now, let’s walk through the implementation of a FastAPI application that serves machine learning model predictions with Redis caching. This setup ensures that repeated requests with the same input are served quickly from the cache, reducing computation time and improving response times. At a high level, the steps are:
1. Train (or load) a machine learning model and save it to disk.
2. Build a FastAPI endpoint that loads the model and returns predictions.
3. Add a Redis caching layer so repeated inputs skip the model call.
4. Test the setup and measure the speedup.
Now, let’s see these steps in more detail.
First, assume that you already have a trained machine learning model that is ready to deploy. In practice, most models are trained offline (a scikit-learn model, a TensorFlow/PyTorch model, etc.), saved to disk, and then loaded into a serving app. For our example, we will create a simple scikit-learn classifier, train it on the famous Iris flower dataset, and save it using joblib. If you already have a saved model file, you can skip the training part and just load it. Here’s how to train a model and then load it for serving:
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import joblib
# Load example dataset and train a simple model (Iris classification)
X, y = load_iris(return_X_y=True)
# Train the model
model = RandomForestClassifier().fit(X, y)
# Save the trained model to disk
joblib.dump(model, "model.joblib")
# Load the pre-trained model from disk (using the saved file)
model = joblib.load("model.joblib")
print("Model loaded and ready to serve predictions.")
In the above code, we used scikit-learn’s built-in Iris dataset, trained a random forest classifier on it, and saved the model to a file called model.joblib. We then loaded it back with joblib.load. The joblib library is the usual choice for saving scikit-learn models because it handles the NumPy arrays inside them efficiently. After this step, we have a model object ready to predict on new data. A heads-up: you can use any pre-trained model here; the way you serve it with FastAPI and cache its results stays more or less the same. The only requirements are that the model exposes a predict method that maps an input to an output, and that it is deterministic, i.e., the same input always produces the same prediction. Caching a non-deterministic model is problematic because the cache would keep returning a result that no longer matches what the model would produce. A quick sanity check is sketched below.
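If you want to double-check determinism before adding a cache, a minimal sanity check could look like the snippet below. It reuses the model object from the training snippet above; the sample values are just an example input.
import numpy as np
# Predict twice on the same input; the results should be identical for a fitted model
sample = [[5.1, 3.5, 1.4, 0.2]]
assert np.array_equal(model.predict(sample), model.predict(sample)), \
    "Model is not deterministic; caching could return misleading results"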
Now that we have a model, let’s serve it via an API. We will use FastAPI to create a web server that handles prediction requests. FastAPI makes it easy to define an endpoint and map request parameters to Python function arguments. In our example, we will assume the model accepts four features, and we will create a GET endpoint /predict that accepts these features as query parameters and returns the model’s prediction.
from fastapi import FastAPI
import joblib
app = FastAPI()
# Load the trained model at startup (to avoid re-loading on every request)
model = joblib.load("model.joblib") # Ensure this file exists from the training step
@app.get("/predict")
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float):
""" Predict the Iris flower species from input measurements. """
# Prepare the features for the model as a 2D list (model expects shape [n_samples, n_features])
features = [[sepal_length, sepal_width, petal_length, petal_width]]
# Get the prediction (in the iris dataset, prediction is an integer class label 0,1,2 representing the species)
prediction = model.predict(features)[0] # Get the first (only) prediction
return {"prediction": str(prediction)}
In the above code, we created a FastAPI app and loaded the model once at startup, because re-loading it on every request would be slow; keeping it in memory means it is always ready to use. FastAPI is one of the fastest Python frameworks, so it can handle a high volume of requests. We then defined a /predict endpoint with @app.get. A GET endpoint makes testing easy since we can pass the inputs directly in the URL, but in real projects you will probably want POST, especially when sending large or complex inputs such as images or JSON payloads (a sketch of a POST version follows below). The function takes four inputs: sepal_length, sepal_width, petal_length, and petal_width, and FastAPI automatically reads and validates them from the query string. Inside the function, we put the inputs into a 2D list (scikit-learn expects a 2D array), call model.predict(), which returns an array of predictions, take the first element, and return it as JSON like {"prediction": "..."}.
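For reference, here is a minimal sketch of what a POST version of the endpoint could look like. It is not the app we build in this post; the IrisFeatures Pydantic model is an illustrative name, and the endpoint accepts the same four measurements as a JSON body:
from fastapi import FastAPI
from pydantic import BaseModel
import joblib

app = FastAPI()
model = joblib.load("model.joblib")  # same file as in the training step

class IrisFeatures(BaseModel):
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float

@app.post("/predict")
def predict(features: IrisFeatures):
    # Pydantic validates the JSON body; unpack it into the 2D shape the model expects
    row = [[features.sepal_length, features.sepal_width,
            features.petal_length, features.petal_width]]
    prediction = model.predict(row)[0]
    return {"prediction": str(prediction)}
You would then call it with a JSON body, e.g. requests.post("http://localhost:8000/predict", json={...}), instead of query parameters. In this post we stick with GET to keep testing simple.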
With that in place, you can run the app with uvicorn main:app --reload, hit the /predict endpoint, and get results. However, even if you send the exact same input again, the model runs again, which is wasteful. The next step is adding Redis to cache previous results and skip the recomputation.
To cache the model output, we will use Redis. First, make sure a Redis server is running. You can install Redis locally or run it in a Docker container; by default it listens on port 6379. We will use the Python redis library (redis-py) to talk to the server.
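If you have Docker installed, one common way to start a local Redis instance is the command below (the container name here is just an example):
docker run --name redis-cache -p 6379:6379 -d redis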
The idea is simple: when a request comes in, create a unique key that represents the input, then check whether that key exists in Redis. If it does, we have cached this input before, so we return the saved result without calling the model. If it does not, we call model.predict, save the output in Redis, and return the prediction.
Let’s now update the FastAPI app to add this cache logic.
# Install the Redis client first if you haven't already: pip install redis
import redis # New import to use Redis
# Connect to a local Redis server (adjust host/port if needed)
cache = redis.Redis(host="localhost", port=6379, db=0)
@app.get("/predict")
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float):
"""
Predict the species, with caching to speed up repeated predictions.
"""
# 1. Create a unique cache key from input parameters
cache_key = f"{sepal_length}:{sepal_width}:{petal_length}:{petal_width}"
# 2. Check if the result is already cached in Redis
cached_val = cache.get(cache_key)
if cached_val:
# If cache hit, decode the bytes to a string and return the cached prediction
return {"prediction": cached_val.decode("utf-8")}
# 3. If not cached, compute the prediction using the model
features = [[sepal_length, sepal_width, petal_length, petal_width]]
prediction = model.predict(features)[0]
# 4. Store the result in Redis for next time (as a string)
cache.set(cache_key, str(prediction))
# 5. Return the freshly computed prediction
return {"prediction": str(prediction)}
In the above code, we added Redis. First, we created a client with redis.Redis(), which connects to the Redis server (db=0 selects the default logical database). Then we built a cache key by simply joining the input values. That works here because the inputs are a few plain numbers, but for more complex inputs it is better to hash a canonical representation such as a JSON string (a sketch follows below); the key just has to be unique for each distinct input. We call cache.get(cache_key): on a hit, we return the stored value immediately, which is fast and avoids rerunning the model. On a miss, we run the model, get the prediction, and save it with cache.set(), so the next time the same input arrives it is served straight from the cache.
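For more complex inputs (say, a dictionary of many features), one way to build a stable key is to hash a canonical JSON representation. The helper below is a minimal sketch, not part of the app above; it also shows redis-py's optional ex argument, which sets a TTL so cached entries expire automatically:
import hashlib
import json

def make_cache_key(payload: dict) -> str:
    # Serialize with sorted keys so the same input always yields the same string
    canonical = json.dumps(payload, sort_keys=True)
    return "predict:" + hashlib.sha256(canonical.encode("utf-8")).hexdigest()

# Example usage (hypothetical payload):
# key = make_cache_key({"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2})
# cache.set(key, str(prediction), ex=3600)  # expire after one hour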
Now that our FastAPI app is running and connected to Redis, it’s time to test how caching improves response time. Below, I use Python’s requests library to call the API twice with the same input and measure the time taken for each call. Make sure your FastAPI server is running before you run the test code:
import requests, time
# Sample input to predict (same input will be used twice to test caching)
params = {
    "sepal_length": 5.1,
    "sepal_width": 3.5,
    "petal_length": 1.4,
    "petal_width": 0.2,
}
# First request (expected to be a cache miss, will run the model)
start = time.time()
response1 = requests.get("http://localhost:8000/predict", params=params)
elapsed1 = time.time() - start
print("First response:", response1.json(), f"(Time: {elapsed1:.4f} seconds)")
# Second request (same params, expected cache hit, no model computation)
start = time.time()
response2 = requests.get("http://localhost:8000/predict", params=params)
elapsed2 = time.time() - start
print("Second response:", response2.json(), f"(Time: {elapsed2:.6f}seconds)")
When you run this, you should see the first request return a result. Then the second request returns the same result, but noticeably faster. For example, you might find the first call took on the order of tens of milliseconds (depending on model complexity), while the second call might be a few milliseconds or less. In our simple demo with a lightweight model, the difference might be small (since the model itself is fast), but the effect is drastic for heavier models.
To put this into perspective, consider what we achieved: the first request for a given input still pays the full cost of model inference, but every repeated request with the same input is answered straight from Redis, skipping the model entirely. Latency for repeats drops to roughly the cost of a cache lookup, and the CPU load on the serving machine falls accordingly.
In real-world deployments, caching can lead to order-of-magnitude improvements. In e-commerce, for example, serving repeat recommendation requests from a Redis cache can take microseconds, versus recomputing them through the full model-serving pipeline. The performance gain depends on how expensive your model inference is: the more complex the model, the more you benefit from caching repeated calls. It also depends on request patterns: if every request is unique, the cache won’t help (there are no repeats to serve from memory), but many applications do see overlapping requests (popular search queries, frequently recommended items, and so on).
You can also check your Redis cache directly to verify that it is storing keys, as sketched below.
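One quick way to inspect the cache is with the same redis-py client (or redis-cli from a terminal). The snippet below is a small sketch that lists whatever keys are currently stored, assuming the local Redis instance used above:
import redis

cache = redis.Redis(host="localhost", port=6379, db=0)
# List all keys and their cached predictions (fine for a small demo cache;
# prefer scan_iter over keys on a large production instance)
for key in cache.keys("*"):
    print(key.decode("utf-8"), "->", cache.get(key).decode("utf-8"))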
In this blog, we demonstrated how FastAPI and Redis can work together to accelerate ML model serving. FastAPI provides a fast and easy-to-build API layer for serving predictions, and Redis adds a caching layer that significantly reduces latency and CPU load for repeated computations. By avoiding repeated model calls, we improved responsiveness and also enabled the system to handle more requests with the same resources.