Triton. Part 4. Serving RF-DETR

Series: Triton

TL;DR: Give Me The Code

If you just want to see the code we discuss today, head over to my GitHub repository at https://github.com/yevhen-k/triton-tutorials and check out these files:

  • client-rfdetr-pytorch.py
  • client-rfdetr-pytorch-async.py
  • models/rfdetr-pytorch/

Introduction

Today, we’re taking a deep dive into several powerful Triton features:

  1. Serving on GPU
  2. Model Versioning
  3. Batching
  4. Asynchronous Serving
  5. Passing Custom Parameters

By now, you already have a solid foundation: you know how to send and receive custom data, configure the server using config.pbtxt, start Triton with custom Python logic, and manage data flow.

Prerequisites:

  • Be familiar with the RF-DETR model and how to run it 1.
  • Have RF-DETR installed in your Docker image (we covered how to build this image in the previous post).
  • Be familiar with asynchronous programming in Python 2.
  • Have access to a suitable NVIDIA GPU for model acceleration.

Our Plan for Today

We’re going to build a complete inference pipeline:

  1. Send an image to the Triton server.
  2. Run inference on the image using RF-DETR to detect bounding boxes for specific classes.
  3. Return a JSON object containing the detection response to the client.

Model Directory Internals

Let’s take a look inside our model’s directory. Here is the structure:

models
└── rfdetr-pytorch
    ├── 1
    │   └── model.py
    ├── 2
    │   └── model.py
    ├── config.pbtxt
    └── rf-detr-medium.pth

You’ll notice two key differences here compared to what we’ve covered so far:

  1. The model’s weights file (rf-detr-medium.pth) is located right next to the config.pbtxt file.
  2. There are two separate folders, 1 and 2, each named after a model version number.

Sync and Async Model Serving

You’ll notice we have two model.py versions. The only real difference between them is that the first is set up for synchronous (sync) serving and the second for asynchronous (async) serving.

To configure your Triton model to run in async mode, all you need to do is add the async keyword to the execute() function signature:

async def execute(
        self, requests: "List[pb_utils.InferenceRequest]"
    ) -> "List[pb_utils.InferenceResponse]":
    ...

It’s that simple! The difference between the two model versions is just that single keyword. Everything else remains identical.

How Does config.pbtxt Manage Model Versions?

To handle the different versions of your model, the Triton server relies on the version_policy setting within the config.pbtxt file 3:

  • All: All versions of the model that are available in the model repository are available for inferencing. version_policy: { all: {}}
  • Latest: Only the latest ‘n’ versions of the model in the repository are available for inferencing. The latest versions of the model are the numerically greatest version numbers. version_policy: { latest: { num_versions: 2}}
  • Specific: Only the specifically listed versions of the model are available for inferencing. version_policy: { specific: { versions: [1,3]}}

For our simple setup, we’re using the version_policy: { all: {}} setting. This tells Triton to load and serve all versions of the model that it finds in the directory (in our case, versions 1 and 2).

How to Configure config.pbtxt to Serve Models on GPU?

To set up Triton to use the GPU, we add the following settings to the config.pbtxt file 4:

instance_group [
    {
        kind: KIND_GPU
        count: 1
    }
]

How to Configure config.pbtxt for Batch Inference?

Triton offers different batching mechanisms for various use cases 5. For our stateless RF-DETR model, dynamic batching is the perfect fit.

The great thing is that Triton Server handles the batch allocation, not you. To create a batch, the server collects several individual requests and combines them into a single batch. Since requests don’t arrive simultaneously, it takes a short time for the server to collect enough requests to form an efficient batch.

Let’s configure Triton to use a maximum batch size of 32, with preferred sizes of 8 and 16, and set the maximum delay for collecting requests to 5,000 microseconds (5 milliseconds).

In the config.pbtxt, this configuration looks like this:

max_batch_size: 32
dynamic_batching {
    preferred_batch_size: [8, 16]
    max_queue_delay_microseconds: 5000
}

Inputs and Outputs

As we mentioned, we’re going to send an image to the server and receive a JSON with the object detections as the response.

Since we are using batching and the tensor size of the raw image bytes isn’t fixed, the inputs and outputs in our config.pbtxt will look like this:

input [
    {
        name: "in:jpg"
        data_type: TYPE_UINT8
        dims: [ -1 ]
    }
]

output [
    {
        name: "detections:json"
        data_type: TYPE_UINT8
        dims: [ -1 ]
    }
]

Notice the -1 dimension. It tells Triton that the tensor has a variable length, which is exactly what we need: the number of encoded JPEG bytes (and the size of the resulting JSON) differs from request to request. Also keep in mind that because max_batch_size is greater than zero, these dims describe a single request only: Triton prepends an implicit batch dimension, so the tensors actually exchanged have the shape [batch_size, -1], and for a batched input we must prepare a correspondingly batched output.
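
As a quick illustration (a minimal sketch with a hypothetical image path), this is how raw JPEG bytes map onto that shape on the client side:

import numpy as np

# Hypothetical image path; the byte count N differs per image, hence dims: [ -1 ]
jpeg_bytes = np.fromfile("image.jpg", dtype=np.uint8)
print(jpeg_bytes.shape)  # (N,) -> matches dims: [ -1 ]

# Because max_batch_size > 0, Triton expects an extra leading batch dimension,
# so the tensor we actually send has shape (1, N)
batched = jpeg_bytes[None, :]
print(batched.shape)  # (1, N)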

Passing Custom Parameters via config.pbtxt

Parameters are configuration values you want to pass directly to your model.py file. It’s easiest to understand this with a real-world example.

In our RF-DETR case, we need to filter the model’s predictions based on a class list and a confidence threshold. Let’s say we only want to detect person and car classes with a threshold of 0.3.

We have a few ways to handle this configuration, but not all of them are ideal:

  • Send a Config with the Image: You could send a JSON config along with the image. However, this means both the client and server must agree on a custom JSON structure.
  • Filter on the Client Side: The client could filter the model’s raw response, but this forces the client to know too much about the server’s internal logic and thresholds.
  • Hardcode in model.py: This works, but it’s poor practice for managing multiple model versions. You’d have to copy and paste configuration settings every time you update the code.
  • Configure via config.pbtxt: This is the best solution. It centralizes all necessary settings in the model’s configuration file, separating logic from deployment settings.

Here is how you pass these custom parameters in the config.pbtxt file:

parameters [
  {
    key: "threshold"
    value: { string_value: "0.3" }
  },
  {
    key: "class_ids"
    value: { string_value: "1, 3" }
  }
]

It’s a bit frustrating, but there isn’t great documentation on how to define these parameters directly in the config.pbtxt file. The only way I was able to figure out this feature was by digging through the source code of the triton-inference-server/common/protobuf/model_config.proto file 6.
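
Putting the pieces from this post together, the full config.pbtxt ends up looking roughly like this (a sketch assembled from the snippets above; the name and backend fields are the usual Python-backend boilerplate, and the actual file in the repository is the source of truth):

name: "rfdetr-pytorch"
backend: "python"
max_batch_size: 32

version_policy: { all: {} }

instance_group [
    {
        kind: KIND_GPU
        count: 1
    }
]

dynamic_batching {
    preferred_batch_size: [8, 16]
    max_queue_delay_microseconds: 5000
}

input [
    {
        name: "in:jpg"
        data_type: TYPE_UINT8
        dims: [ -1 ]
    }
]

output [
    {
        name: "detections:json"
        data_type: TYPE_UINT8
        dims: [ -1 ]
    }
]

parameters [
  {
    key: "threshold"
    value: { string_value: "0.3" }
  },
  {
    key: "class_ids"
    value: { string_value: "1, 3" }
  }
]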

model.py

Reading Model Weights and Config

To initialize the RF-DETR model with its weights and configure inference, we use the following code:

class TritonPythonModel:
    def initialize(self, args: Dict[str, str]) -> None:
        # Path to the repository directory: `models/rfdetr-pytorch`
        model_repository_path: str = args["model_repository"]

        # Initialize the model
        self.model = RFDETRMedium(
            pretrain_weights=f"{model_repository_path}/rf-detr-medium.pth"
        )

        # Load `config.pbtxt` file as Json
        self.model_config = model_config = json.loads(args["model_config"])
        print(f"{model_config=}")

        # Get model parameters.
        # Note that the parameters we get from `model_config` are plain strings,
        # so we have to parse them manually.
        # Here we parse
        # {
        #     key: "class_ids"
        #     value: { string_value: "1, 3" }
        # }
        self.class_ids: np.ndarray = np.fromstring(
            model_config["parameters"]["class_ids"]["string_value"],
            sep=", ",
            dtype=np.int32,
        )
        # Here we parse
        # {
        #     key: "threshold"
        #     value: { string_value: "0.3" }
        # }
        self.threshold = float(model_config["parameters"]["threshold"]["string_value"])

        # The code below is the same as in the previous examples

        # Get INPUT0 configuration
        input_config = pb_utils.get_input_config_by_name(model_config, "in:jpg")
        assert input_config
        # Convert Triton types to numpy types
        self.input_dtype = pb_utils.triton_string_to_numpy(input_config["data_type"])

        # Get OUTPUT0 configuration
        output_config = pb_utils.get_output_config_by_name(
            model_config, "detections:json"
        )
        assert output_config
        # Convert Triton types to numpy types
        self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])

Batch Inference and Response

At the inference stage (inside the execute() function), the first thing we do is collect the images from the batched requests:

images = []
for request in requests:
    input_tensor = pb_utils.get_input_tensor_by_name(request, "in:jpg")
    input_arr: np.ndarray = input_tensor.as_numpy()
    image = cv2.imdecode(input_arr, cv2.IMREAD_UNCHANGED)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    images.append(image)

Next, we feed the batch to the model and get detections:

detections: Detections | List[Detections] = self.model.predict(
    images, threshold=self.threshold
)

The RF-DETR model returns a single Detections object for one image and a List[Detections] for a batch. From these detections we compose one JSON payload per request and send it back to the client, as sketched below.
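
The response-building step is not shown here in full, so below is a minimal sketch of how it could look at the end of execute(), assuming supervision-style Detections objects with xyxy, confidence, and class_id attributes (the repository contains the actual implementation):

# Normalize to a list so the single-image and batched cases are handled uniformly
if not isinstance(detections, list):
    detections = [detections]

responses = []
for dets in detections:
    # Keep only the classes listed in the `class_ids` parameter from config.pbtxt
    keep = np.isin(dets.class_id, self.class_ids)
    payload = {
        "boxes": dets.xyxy[keep].tolist(),
        "scores": dets.confidence[keep].tolist(),
        "class_ids": dets.class_id[keep].tolist(),
    }
    # Encode the JSON as UINT8 bytes and prepend the batch dimension
    # expected when max_batch_size > 0
    json_bytes = json.dumps(payload).encode("utf-8")
    out_arr = np.frombuffer(json_bytes, dtype=self.output_dtype).reshape(1, -1)
    out_tensor = pb_utils.Tensor("detections:json", out_arr)
    responses.append(pb_utils.InferenceResponse(output_tensors=[out_tensor]))
return responses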

Client Internals

The client side consists of two main parts (a condensed sketch follows the list):

  • Sending an image. The code is the same as in the echo-image example.
  • Receiving a JSON response. The code is the same as in the echo-JSON example.
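
Here is that sketch (with a hypothetical image path; the exact JSON fields depend on how model.py builds the payload):

import json

import numpy as np
import tritonclient.grpc as grpcclient

model_name = "rfdetr-pytorch"
grpc_client = grpcclient.InferenceServerClient(url="localhost:8001", verbose=False)

# Pack the raw JPEG bytes into a UINT8 tensor of shape [1, N]
# (batch dimension first, then the byte count)
image_bytes = np.fromfile("image.jpg", dtype=np.uint8)
inputs = [grpcclient.InferInput("in:jpg", [1, image_bytes.size], "UINT8")]
inputs[0].set_data_from_numpy(image_bytes[None, :])

# Ask for the JSON output declared in config.pbtxt
outputs = [grpcclient.InferRequestedOutput("detections:json")]

# model_version is omitted here; selecting a specific version is shown in the next section
results = grpc_client.infer(model_name, inputs, outputs=outputs)

# Decode the UINT8 bytes back into a Python object
detections = json.loads(results.as_numpy("detections:json").tobytes().decode("utf-8"))
print(detections)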

How to Call Sync and Async Versions of the Model?

Sync Client

The gist of the sync client is as follows:

import tritonclient.grpc as grpcclient

grpc_client = grpcclient.InferenceServerClient(url="localhost:8001", verbose=False)

results = grpc_client.infer(
    model_name,
    inputs,
    outputs=outputs,
    model_version="1",
)

We import grpcclient from tritonclient.grpc and explicitly ask it to send the request to model version 1.

Async Client

The gist of the async client is as follows:

import tritonclient.grpc.aio as grpcclient

grpc_client = grpcclient.InferenceServerClient(url="localhost:8001", verbose=False)

results = await grpc_client.infer(
    model_name,
    inputs,
    outputs=outputs,
    model_version="2",
)

Here we import grpcclient from tritonclient.grpc.aio and explicitly ask it to send the request to model version 2. Note that the inference call must be awaited.
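
To get real benefit from the async model (and to give the dynamic batcher something to batch), the client should keep several requests in flight at once. Below is a minimal sketch, assuming a hypothetical build_inputs() helper that prepares the "in:jpg" tensor and the requested "detections:json" output exactly as in the sync client:

import asyncio

import tritonclient.grpc.aio as grpcclient


async def main() -> None:
    grpc_client = grpcclient.InferenceServerClient(url="localhost:8001", verbose=False)
    inputs, outputs = build_inputs()  # hypothetical helper, see the sync client above

    # Fire many requests concurrently so the server can group them into batches
    tasks = [
        grpc_client.infer("rfdetr-pytorch", inputs, outputs=outputs, model_version="2")
        for _ in range(16)
    ]
    results = await asyncio.gather(*tasks)
    print(f"received {len(results)} responses")


asyncio.run(main())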

Home Assignment

  1. What is the difference between dynamic batcher and sequence batcher? What are the primary use cases for each one?
  2. For the asynchronous client example, try experimenting with Python’s concurrency tools: semaphores, asyncio.create_task(), and the modern asyncio.TaskGroup(). See how each affects the client’s performance and parallelism.
  3. Log the batch size on the server side: modify your model.py file to log the number of requests received by execute(). Can you achieve a batch size greater than 1 when using the synchronous version of the model? (Hint: think about how the client handles requests.)

Wrapping Up

Today, we covered a lot of new and important ground with Triton! We learned how to serve models on a GPU, how to take advantage of dynamic batching, how to pass custom parameters using the config file, and how to run a model in asynchronous mode. This knowledge forms the essential basis for deploying any serious model with Triton.

The only thing left in the series of tutorials is ensemble serving, which we’ll cover next time.

References