diff --git a/docs/en/guides/triton-inference-server.md b/docs/en/guides/triton-inference-server.md
index 09f7516b1..0151cc078 100644
--- a/docs/en/guides/triton-inference-server.md
+++ b/docs/en/guides/triton-inference-server.md
@@ -83,25 +83,34 @@ The Triton Model Repository is a storage location where Triton can access and lo
     # (Optional) Enable TensorRT for GPU inference
     # First run will be slow due to TensorRT engine conversion
-    import json
-
-    data = {
-        "optimization": {
-            "execution_accelerators": {
-                "gpu_execution_accelerator": [
-                    {
-                        "name": "tensorrt",
-                        "parameters": {"key": "precision_mode", "value": "FP16"},
-                        "parameters": {"key": "max_workspace_size_bytes", "value": "3221225472"},
-                        "parameters": {"key": "trt_engine_cache_enable", "value": "1"},
-                    }
-                ]
-            }
+    data = """
+    optimization {
+      execution_accelerators {
+        gpu_execution_accelerator {
+          name: "tensorrt"
+          parameters {
+            key: "precision_mode"
+            value: "FP16"
+          }
+          parameters {
+            key: "max_workspace_size_bytes"
+            value: "3221225472"
+          }
+          parameters {
+            key: "trt_engine_cache_enable"
+            value: "1"
+          }
+          parameters {
+            key: "trt_engine_cache_path"
+            value: "/models/yolo/1"
+          }
         }
+      }
     }
+    """
 
     with open(triton_model_path / "config.pbtxt", "w") as f:
-        json.dump(data, f, indent=4)
+        f.write(data)
     ```
 
 ## Running Triton Inference Server
 
@@ -124,7 +133,7 @@ subprocess.call(f"docker pull {tag}", shell=True)
 # Run the Triton server and capture the container ID
 container_id = (
     subprocess.check_output(
-        f"docker run -d --rm -v {triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models",
+        f"docker run -d --rm --gpus 0 -v {triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models",
         shell=True,
     )
     .decode("utf-8")
@@ -215,7 +224,7 @@ Setting up [Ultralytics YOLO11](https://docs.ultralytics.com/models/yolov8/) wit
 
 container_id = (
     subprocess.check_output(
-        f"docker run -d --rm -v {triton_repo_path}/models -p 8000:8000 {tag} tritonserver --model-repository=/models",
+        f"docker run -d --rm --gpus 0 -v {triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models",
         shell=True,
     )
    .decode("utf-8")
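
Two notes for reviewers on the `config.pbtxt` change above: Triton parses `config.pbtxt` as protobuf text format, not JSON, and the removed Python dict was additionally broken because its duplicate `"parameters"` keys silently overwrite one another, leaving only the last entry. The new string can be sanity-checked against Triton's own schema before merging — a minimal sketch, assuming `tritonclient[grpc]` is installed (it bundles `model_config_pb2`) and the repository layout used in this guide (`models/yolo/config.pbtxt`):

```python
from google.protobuf import text_format
from tritonclient.grpc import model_config_pb2

# Parse the generated config against Triton's ModelConfig schema;
# text_format.ParseError is raised on any syntax or field-name mistake.
with open("models/yolo/config.pbtxt") as f:
    config = text_format.Parse(f.read(), model_config_pb2.ModelConfig())

# Should print the four TensorRT accelerator parameters set above.
print(config.optimization.execution_accelerators.gpu_execution_accelerator)
```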
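
Once the container from the `docker run` snippet is up, the quickest end-to-end check is the client flow this guide already documents — shown here assuming the server is listening on `localhost:8000` and the model directory is named `yolo`:

```python
from ultralytics import YOLO

# The http:// scheme selects Triton's HTTP endpoint; "yolo" must match
# the model folder name in the repository mounted at /models.
model = YOLO("http://localhost:8000/yolo", task="detect")

# The first call is slow while the TensorRT engine is built and cached
# under /models/yolo/1 (the trt_engine_cache_path configured above).
results = model("https://ultralytics.com/images/bus.jpg")
```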