The Core Concepts of Triton
In the following sections we’ll briefly answer the key questions about the Triton configuration needed for model serving:
- What needs to be configured?
- Where does the configuration happen?
- How do we actually configure it?
Tensors: The Heart of Triton
So, what needs to be configured?
The entire Triton Inference Server (from here on, I’ll just say “Triton” to refer to both the client and server) is built around a single, powerful idea: tensors. Yep, those very same tensors you’re already familiar with from deep learning for inputs, outputs, and matrix representations.
To handle these tensors on both the server and client sides, Triton relies on NumPy, a package every ML and data science pro knows like the back of their hand.
Each tensor has three essential properties we need to set up:
- Shape: The dimensions of your data (e.g., [1, 224, 224, 3]).
- Data Type: The type of data it holds (e.g., float32, int64).
- Name: This one might seem a bit odd, but think about it: every variable in Python has a name, PyTorch layers have names, and ONNX models have names for their nodes. So, it’s only natural that Triton uses names to identify these tensors.
Keep in mind: At its core, a tensor is just a block of bytes. If your data can be represented as bytes, you can send it to and receive it from Triton.
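To make these three properties concrete, here’s a tiny NumPy sketch (purely illustrative, not tied to any Triton API yet; the tensor name is just a string I picked):

import numpy as np

# The "tensor" we'll later send to Triton: a named, typed, shaped array.
tensor_name = "input:tensor"  # the name we'll declare in config.pbtxt later
data = np.zeros((1, 224, 224, 3), dtype=np.float32)
print(tensor_name, data.shape, data.dtype)  # input:tensor (1, 224, 224, 3) float32
print(len(data.tobytes()))  # 602112 -- the raw bytes are what actually travel over the wire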
Model Repository
Where does it need to be configured?
On the server side, all the action takes place in the model repository. You’ll need to set this up to follow a specific, strict layout [1]:
- for Python:
  <model-repository-path>/
    <model-name>/
      config.pbtxt
      1/
        model.py
- for ONNX:
  <model-repository-path>/
    <model-name>/
      config.pbtxt
      1/
        model.onnx
A Real-World Example
To put this into perspective, let’s look at a concrete example of a model repository layout:
tree models/
models
└── echo-tensor
├── 1
│ └── model.py
└── config.pbtxt
In this example, `echo-tensor` is the name of our model, and `1` is its version. Next to the version folder you’ll find `config.pbtxt`, which holds all the configuration for the Triton server, and inside the version folder lives `model.py`, which contains our model’s code.
Think of `config.pbtxt` as a blueprint. It tells the Triton server how to process incoming client requests and how to feed that data to the model defined in `model.py`.
This brings us back to our tensors. Their shapes, data types, and names must be configured in two places:
- On the server side, within both the `config.pbtxt` file and the `model.py` script (for Python-based models).
- On the client side, when you’re preparing your request.
Tensor Schema
So, how do we configure it?
The easiest way to understand this is to walk through an example.
Example Model
Let’s imagine we have a simple model named “echo-tensor”. This model takes a single input tensor with a shape of `[2, 3]` and a data type of `float32`. All it does is return the exact same tensor, unchanged, as its output.
The `config.pbtxt` File
To get this model running, we need to describe it in the `config.pbtxt` file. This file must specify the model name, its inputs, and its outputs:
name: "echo-tensor"
max_batch_size: 0
backend: "python"
input [
{
name: "input:tensor"
data_type: TYPE_FP32
dims: [ 2, 3 ]
}
]
output [
{
name: "output:tensor"
data_type: TYPE_FP32
dims: [ 2, 3 ]
}
]
instance_group [
{
count: 1
kind: KIND_CPU
}
]
The configuration starts with some general settings:
- `name`: A string that uniquely identifies your model. For our example, this is `"echo-tensor"`.
- `max_batch_size`: This controls the largest batch Triton may assemble for your model. If it’s set to `0`, batching is disabled and you must specify all dimensions of your input and output tensors. If it’s greater than `0`, the first dimension is treated as the batch dimension and should be omitted from your `dims` list (see the sketch right after this list). We’ll dive deeper into this later.
- `backend`: This tells Triton what kind of model you’re using (e.g., Python, ONNX, TensorFlow). You can find a list of all available backends here [2].
- `instance_group`: This section allocates the hardware resources for your model. It lets you specify whether to use a CPU (`KIND_CPU`) or a GPU (`KIND_GPU`) and how many instances to run.
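To illustrate the `max_batch_size` rule, here’s a hypothetical variation of our config (the model name is made up and we won’t use this file anywhere): with a positive `max_batch_size`, the leading batch dimension is implicit, so `dims` lists only the per-sample shape.

name: "echo-tensor-batched"  # hypothetical name, for illustration only
max_batch_size: 8
backend: "python"
input [
  {
    name: "input:tensor"
    data_type: TYPE_FP32
    dims: [ 2, 3 ]  # actual request tensors will have shape [batch, 2, 3]
  }
]
output [
  {
    name: "output:tensor"
    data_type: TYPE_FP32
    dims: [ 2, 3 ]
  }
]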
When you define your input and output tensors, you’ll see how we use `name`, `data_type`, and `dims` to describe each one. The syntax allows for multiple inputs and outputs, and each can have its own distinct properties. You can find a complete list of supported data types here [3].
Table: Correspondence between data types in config.pbtxt and models

| Model Config | TensorRT | ONNX Runtime | PyTorch | API   | NumPy         |
|--------------|----------|--------------|---------|-------|---------------|
| TYPE_BOOL    | kBOOL    | BOOL         | kBool   | BOOL  | bool          |
| TYPE_UINT8   | kUINT8   | UINT8        | kByte   | UINT8 | uint8         |
| TYPE_UINT16  |          | UINT16       |         | UINT16| uint16        |
| TYPE_UINT32  |          | UINT32       |         | UINT32| uint32        |
| TYPE_UINT64  |          | UINT64       |         | UINT64| uint64        |
| TYPE_INT8    | kINT8    | INT8         | kChar   | INT8  | int8          |
| TYPE_INT16   |          | INT16        | kShort  | INT16 | int16         |
| TYPE_INT32   | kINT32   | INT32        | kInt    | INT32 | int32         |
| TYPE_INT64   | kINT64   | INT64        | kLong   | INT64 | int64         |
| TYPE_FP16    | kHALF    | FLOAT16      |         | FP16  | float16       |
| TYPE_FP32    | kFLOAT   | FLOAT        | kFloat  | FP32  | float32       |
| TYPE_FP64    |          | DOUBLE       | kDouble | FP64  | float64       |
| TYPE_STRING  |          | STRING       |         | BYTES | dtype(object) |
| TYPE_BF16    | kBF16    |              |         | BF16  |               |
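In practice, you rarely need to read this table by hand on the Python side: `tritonclient.utils` ships helpers that convert between NumPy dtypes and the Triton API type strings (we’ll use `np_to_triton_dtype` in the client below). A quick sketch:

import numpy as np
from tritonclient.utils import np_to_triton_dtype, triton_to_np_dtype

print(np_to_triton_dtype(np.dtype(np.float32)))  # FP32
print(triton_to_np_dtype("FP32"))  # <class 'numpy.float32'>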
The Client Side
Once your model is configured on the server, the next step is to set up the client to match the schema defined in your `config.pbtxt` file:
# client-echo-tensor.py
import numpy as np
import tritonclient.grpc as grpcclient
import tritonclient.utils as utils

triton_server_url = "localhost:8001"
model_name = "echo-tensor"
input_tensor_name = "input:tensor"
output_tensor_name = "output:tensor"

# 01. Make a client
grpc_client = grpcclient.InferenceServerClient(url=triton_server_url, verbose=False)

# 02. Prepare data for the model
#     with a proper shape and type
np_data = np.array(
    [
        [1, 2, 3],
        [4, 5, 6],
    ],
    dtype=np.float32,
)

# 03. Convert data to the Triton internal representation:
#     - set inputs
#     - set outputs
inputs = [
    grpcclient.InferInput(
        input_tensor_name, np_data.shape, utils.np_to_triton_dtype(np_data.dtype)
    )
]
inputs[0].set_data_from_numpy(np_data)

outputs = [
    grpcclient.InferRequestedOutput(output_tensor_name),
]

# 04. Make a request to the server
results = grpc_client.infer(model_name, inputs, outputs=outputs)

# 05. Convert results back to numpy
response_arr = results.as_numpy(output_tensor_name)
Triton provides four types of clients to communicate with the server:
- HTTP Client: `import tritonclient.http as httpclient`
- Asynchronous HTTP Client: `import tritonclient.http.aio as httpclient`
- gRPC Client: `import tritonclient.grpc as grpcclient`
- Asynchronous gRPC Client: `import tritonclient.grpc.aio as grpcclient`
All four clients share the same API, so you can switch between them easily. We’ll explore the asynchronous clients in more detail later on, as they’re great for improving performance.
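As a quick preview of the asynchronous flavor, here’s a minimal sketch of the same echo request using the async gRPC client (the file name is mine, and I’m importing `InferInput`/`InferRequestedOutput` from the synchronous module, which I’m assuming the aio client accepts):

# client-echo-tensor-async.py (hypothetical file name)
import asyncio

import numpy as np
import tritonclient.grpc.aio as aio_grpcclient
import tritonclient.utils as utils
from tritonclient.grpc import InferInput, InferRequestedOutput


async def main() -> None:
    # Same API shape as the sync client, but infer() is awaited.
    client = aio_grpcclient.InferenceServerClient(url="localhost:8001")

    np_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
    inputs = [
        InferInput("input:tensor", np_data.shape, utils.np_to_triton_dtype(np_data.dtype))
    ]
    inputs[0].set_data_from_numpy(np_data)
    outputs = [InferRequestedOutput("output:tensor")]

    results = await client.infer("echo-tensor", inputs, outputs=outputs)
    print(results.as_numpy("output:tensor"))

    await client.close()


if __name__ == "__main__":
    asyncio.run(main())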
Setting Up the Server
Now that we’ve covered the client, let’s turn our attention to the server. The first step is to organize our model repository according to the model name we’ve chosen:
tree models/
models
└── echo-tensor
├── 1
│ └── model.py
└── config.pbtxt
The `model.py` File
We’ve already prepared the content for `config.pbtxt`, so now it’s time to create our model’s logic in the `model.py` script.
Every Python-based Triton model must follow a specific template. This structure ensures that Triton can properly load your model, initialize it, and run inference:
from typing import Dict, List

import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args: Dict[str, str]) -> None:
        """
        This function is optional. Executed only once when the model is loaded.

        Parameters
        ----------
        args : dict
            Both keys and values are strings. The dictionary keys and values are:
            * model_config: A JSON string containing the model configuration
            * model_instance_kind: A string containing model instance kind
            * model_instance_device_id: A string containing model instance device ID
            * model_repository: Model repository path
            * model_version: Model version
            * model_name: Model name
        """
        ...

    def execute(
        self, requests: "List[pb_utils.InferenceRequest]"
    ) -> "List[pb_utils.InferenceResponse]":
        """
        This function is called when an inference request is made.
        For each `request` in `requests` a response must be prepared.
        Basically, len(requests) == batch_size.

        If there is an error, you can set the error argument
        when creating a pb_utils.InferenceResponse:
        # pb_utils.InferenceResponse(
        #     output_tensors=..., TritonError("An error occurred"))
        """
        ...

    def finalize(self) -> None:
        """
        This function is optional. Executed only once when the model is unloaded.
        """
        ...
Now, with that template in mind, let’s write the code for our “echo-tensor” model example:
# model.py
import json
from typing import Dict, List

import numpy as np
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args: Dict[str, str]) -> None:
        # NOTE: here we're loading the config from the `config.pbtxt` file
        # and getting all the necessary data:
        # - tensor data type
        # - tensor shape
        self.model_config = json.loads(args["model_config"])
        output_config = pb_utils.get_output_config_by_name(
            self.model_config, "output:tensor"
        )
        self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
        self.output_shape = output_config["dims"]

    def execute(
        self, requests: "List[pb_utils.InferenceRequest]"
    ) -> "List[pb_utils.InferenceResponse]":
        responses = []
        for request in requests:
            # Get input by name
            input_tensor = pb_utils.get_input_tensor_by_name(request, "input:tensor")
            input_arr: np.ndarray = input_tensor.as_numpy()

            # Got input numpy array: `input_arr`.
            # Process `input_arr` here
            # [...]
            # print(f"Got np.array: {input_arr}", flush=True)

            # Prepare output by name
            outputs = input_arr
            out_tensor = pb_utils.Tensor(
                "output:tensor", outputs.astype(self.output_dtype)
            )
            inference_response = pb_utils.InferenceResponse(output_tensors=[out_tensor])
            responses.append(inference_response)
        return responses
Running Your Model
We’ve now described our input and output tensors in the model script, the `config.pbtxt` file, and the client code. All that’s left is to deploy the model and run our first inference.
Starting the Server
To get your model up and running on the server, just use a simple Docker command:
docker run --gpus=1 --rm --net=host -v ${PWD}/models:/models nvcr.io/nvidia/tritonserver:24.08-py3 tritonserver --model-repository=/models
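Since our `config.pbtxt` requests `KIND_CPU` instances, the GPU flag isn’t strictly needed for this particular example; on a machine without a GPU you should be able to simply drop it (same image, same repository mount):

docker run --rm --net=host -v ${PWD}/models:/models nvcr.io/nvidia/tritonserver:24.08-py3 tritonserver --model-repository=/models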
If everything loads correctly, you’ll see a log that looks something like this, indicating your model is ready for requests:
+--------------------+---------+--------+
| Model | Version | Status |
+--------------------+---------+--------+
| echo-tensor | 1 | READY |
+--------------------+---------+--------+
You can even use `curl` to send a quick request to check the model’s status and make sure it’s alive and well:
curl -X POST localhost:8000/v2/repository/index | jq
[
{
"name": "echo-tensor",
"version": "1",
"state": "READY"
}
]
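You can also hit the standard health endpoints to confirm that the server and the specific model are ready (both should return HTTP 200 once loading is done):

curl -v localhost:8000/v2/health/ready
curl -v localhost:8000/v2/models/echo-tensor/ready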
Running Inference from the Client
Finally, to send a request to your model from the client side, just run your script:
python client-echo-tensor.py
If all goes according to plan, the response you get back will be the exact same tensor you sent. You’ve successfully completed your first end-to-end inference with Triton!
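If you’d like the script to verify that for you, you could append a quick check to the end of `client-echo-tensor.py` (a small addition of mine, not part of the original listing):

# Optional: confirm the echoed tensor matches what we sent
np.testing.assert_array_equal(response_arr, np_data)
print("Echo OK:", response_arr.shape, response_arr.dtype)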
For the full code of the `echo-tensor` example, please visit https://github.com/yevhen-k/triton-tutorials and check the following files and directories:
- `client-echo-tensor.py`
- `models/echo-tensor`
Home Assignment
As a home assignment, just play around with the code we’ve made:
- Change the data type
- Change the input/output shapes
- Change the model name
- Change the input/output names
- Apply transformations to the input tensor, for example:
  - Double it
  - Reshape it
- Try to make a model with two inputs and two outputs that returns the sum and the difference of the two inputs (a hint follows this list)
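As a hint for that last exercise, one possible tensor schema (the names are my own invention; writing the matching `execute()` is up to you) could look like this in `config.pbtxt`:

input [
  {
    name: "input:a"
    data_type: TYPE_FP32
    dims: [ 2, 3 ]
  },
  {
    name: "input:b"
    data_type: TYPE_FP32
    dims: [ 2, 3 ]
  }
]
output [
  {
    name: "output:sum"
    data_type: TYPE_FP32
    dims: [ 2, 3 ]
  },
  {
    name: "output:diff"
    data_type: TYPE_FP32
    dims: [ 2, 3 ]
  }
]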
For more tutorials and examples, visit:
- triton-inference-server/tutorials [4]
- triton-inference-server/python_backend/examples [5]
- triton-inference-server/server [6]
Wrapping It Up
Getting a Triton server up and running can feel a bit complex at first, with a lot of small details to keep track of. But we did it! We successfully set up a basic echo server and learned how to send and receive tensors – the fundamental building blocks of Triton.
Throughout this series, we discovered that every input and output tensor must be precisely defined by its name, shape, and data type within the crucial `config.pbtxt` file.
By the end of this tutorial, you should have a solid understanding of a minimal project structure and feel confident in your ability to prepare data for inference, execute the inference, and handle the response.
References
[1] Model Files: https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_repository.html#model-files
[2] Where can I find all the backends that are available for Triton? https://github.com/triton-inference-server/backend?tab=readme-ov-file#where-can-i-find-all-the-backends-that-are-available-for-triton
[3] Triton Data Types: https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_configuration.html#datatypes
[4] Tutorials and examples for Triton Inference Server: https://github.com/triton-inference-server/tutorials
[5] Python Triton Backend Examples: https://github.com/triton-inference-server/python_backend/tree/main/examples
[6] Triton Inference Server: https://github.com/triton-inference-server/server