Jetson Orin 部署 6D 位姿估计模型实战:从 PyTorch 到 TensorRT 全流程

作者: 侯本洲 | 具身智能算法工程师
标签: #JetsonOrin #TensorRT #6D位姿估计 #模型部署 #边缘计算 #GDRNPP


前言:为什么我们需要在边缘端跑 6D 位姿估计?

在过去的一年里,我一直深耕工业物流拆垛项目。在这个项目中,核心难点不是识别"箱子在哪里"(2D 检测),而是要精确知道"箱子在空间中的具体姿态"(6D Pose),以便机械臂能够精准抓取。

我们选用了 GDRNPP(Geometry-Guided Direct Regression Network Plus Plus)作为基础模型,它在 BOP Challenge 上表现优异,但原版 PyTorch 实现有个致命问题:

在 NVIDIA Jetson AGX Orin 上直接运行 PyTorch 推理,单帧耗时接近 2.5秒。对于流水线上的工业机器人来说,这个延迟是不可接受的。我们的目标是 0.5秒 以内。

经过两周的魔改、剪枝和 TensorRT 加速,我们最终将推理时间压缩到了 0.3秒,实现了近 8倍 的性能提升。这篇文章将复盘整个 Jetson Orin 部署 6D 位姿估计模型 的实战全流程,包含大量代码细节和踩坑经验。


一、 环境准备与模型选型

1.1 硬件平台

  • 设备:NVIDIA Jetson AGX Orin (64GB)
  • 系统:JetPack 6.0 (Ubuntu 22.04)
  • CUDA:12.2
  • TensorRT:8.6.2

1.2 为什么选择 GDRNPP?

在 6D 位姿估计领域,PoseCNN、PVNet 已经略显老态。GDRNPP 通过结合几何特征和直接回归,在遮挡严重的工业混叠场景下(比如乱堆的快递箱)鲁棒性极强。

但是,GDRNPP 的网络结构比较复杂,包含了:

  1. Backbone: ConvNeXt 或 ResNet(我们选用 ResNet-34 以平衡速度)。
  2. Region Extraction: 从 2D 检测框中抠图。
  3. PNP Solver: 这一点最坑,原版代码很多后处理是在 CPU 上跑的。

二、 第一步:模型瘦身与 ONNX 导出

PyTorch 导出 ONNX 是所有部署工作的第一步,也是报错最多的一步。

2.1 修改模型结构(Model Surgery)

GDRNPP 源码中大量使用了动态控制流(if-else)和一些 TensorRT 不支持的 Grid Sample 操作。为了顺利导出,我们需要对模型进行"手术"。

踩坑点 1:去除无关输出
训练时的 Loss 计算分支在推理时毫无用处,必须切掉。

import torch
import torch.nn as nn
from gdrnpp.model import GDRNPP

class GDRNPP_Inference(nn.Module):
    def __init__(self, original_model):
        super().__init__()
        self.backbone = original_model.backbone
        self.rot_head = original_model.rot_head
        self.trans_head = original_model.trans_head
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        features = self.backbone(x)
        rot_pred = self.rot_head(features)
        trans_pred = self.trans_head(features)
        return rot_pred, trans_pred

2.2 导出 ONNX 代码实战

这里一定要注意 opset_version,Jetson Orin 上的 TensorRT 8.6 对 Opset 17 支持较好。

def export_onnx(model, checkpoint_path, output_path):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model.eval()
    model.cuda()

    dummy_input = torch.randn(1, 3, 256, 256).cuda()

    print("开始导出 ONNX...")
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=['input_rgb'],
        output_names=['pred_rot', 'pred_trans'],
        dynamic_axes={
            'input_rgb': {0: 'batch_size'},
            'pred_rot': {0: 'batch_size'},
            'pred_trans': {0: 'batch_size'}
        }
    )
    print(f"导出成功:{output_path}")

    import onnx
    from onnxsim import simplify
    onnx_model = onnx.load(output_path)
    model_simp, check = simplify(onnx_model)
    assert check, "Simplified ONNX model could not be validated"
    onnx.save(model_simp, output_path.replace(".onnx", "_sim.onnx"))
    print("ONNX 模型已简化。")

踩坑点 2:GridSample 的对齐问题
PyTorch 的 grid_sample 参数 align_corners 在不同版本 TensorRT 中行为不一致。建议在导出前,尽量用标准的 Resize 操作替代非必要的 GridSample,或者确保 TensorRT 版本是最新的。


三、 第二步:Jetson Orin 上的 TensorRT 转换

拿到了 ONNX,我们进入核心环节:在边缘计算设备上将其编译为 TensorRT Engine。

3.1 使用 trtexec 进行基准测试

/usr/src/tensorrt/bin/trtexec \
  --onnx=gdrnpp_sim.onnx \
  --saveEngine=gdrnpp_fp16.engine \
  --fp16 \
  --workspace=4096 \
  --verbose

关键参数:

  • --fp16必开! Orin 的 Tensor Core 对 FP16 有极强的加速能力,且对位姿估计的精度影响微乎其微(误差 < 1mm)。
  • --workspace:给 TensorRT 分配的显存暂存空间,Orin 内存大,给 4GB 没问题。

3.2 Python API 构建 Engine

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def build_engine(onnx_file_path, engine_file_path):
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, TRT_LOGGER)

    if builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        print("启用 FP16 精度加速")

    with open(onnx_file_path, 'rb') as model:
        if not parser.parse(model.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None

    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)

    print("正在构建 TensorRT Engine...")
    serialized_engine = builder.build_serialized_network(network, config)
    
    with open(engine_file_path, "wb") as f:
        f.write(serialized_engine)
    print("Engine 构建完成!")

四、 第三步:推理优化与后处理加速

4.1 使用 CUDA 算子加速前处理

不要用 OpenCV 的 cv2.resize,它们在 CPU 上执行。用 TorchVision 直接在 GPU 上处理:

import torchvision.transforms as T

preprocess = T.Compose([
    T.Resize((256, 256), antialias=True),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def gpu_preprocess(frame_tensor):
    frame_tensor = frame_tensor.float() / 255.0
    return preprocess(frame_tensor)

4.2 异步推理

利用 CUDA Stream 实现数据拷贝与计算的重叠:

class TRTWrapper:
    def __init__(self, engine_path):
        self.logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())
        self.context = self.engine.create_execution_context()
        self.inputs, self.outputs, self.bindings, self.stream = common.allocate_buffers(self.engine)

    def infer(self, input_tensor):
        self.inputs[0].device = input_tensor.data_ptr()
        self.context.execute_async_v2(
            bindings=self.bindings, 
            stream_handle=self.stream.handle
        )
        self.stream.synchronize()
        return [out.host for out in self.outputs]

4.3 PnP 瓶颈优化

GDRNPP 输出的密集对应图需要 PnP 解算 6D 位姿,OpenCV 的 solvePnP 是 CPU 实现,非常慢。我们最终采用了**“网络直接回归位姿 + 关键帧 ICP 精修”**的策略,完美平衡了速度与精度。


五、 性能对比

测试环境:Jetson AGX Orin (30W),输入 640x480,Batch Size = 1。

部署方案 精度 前处理 推理 后处理 总耗时/帧 FPS
PyTorch 原生 FP32 120ms 2100ms 250ms ~2470ms 0.4
ONNX Runtime FP32 80ms 850ms 250ms ~1180ms 0.8
TensorRT (基础) FP32 80ms 180ms 250ms ~510ms 1.9
TensorRT + FP16 + CUDA前处理 FP16 5ms 45ms 250ms ~300ms 3.3

六、 避坑指南

  1. 版本对应:JetPack、CUDA、TensorRT 的版本必须严格对应
  2. 固定 Batch Size:静态 Batch 能让编译器做更激进的层融合优化
  3. 算子不支持:后处理部分果断切掉放到 CPU 或用 CUDA C++ 重写插件
  4. 散热:AGX Orin 全速跑起来很热,务必 sudo jetson_clocks

总结:从 PyTorch 到 TensorRT 的迁移,前期配置繁琐,但带来的性能收益是巨大的。在边缘计算领域,榨干 GPU 的每一滴性能,是算法工程师的浪漫。


如果你对具身智能、Jetson 开发或 TensorRT 部署感兴趣,欢迎评论区交流!

声明:本文代码基于实际项目脱敏修改,仅供参考。转载请注明出处。

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐