Apollo 11.0
自动驾驶开放平台
multi_agent_vehicle_tensorrt.h
浏览该文件的文档.
1/******************************************************************************
2 * Copyright 2023 The Apollo Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *****************************************************************************/
16
17#pragma once
18
19#include <string>
20#include <vector>
21#include <memory>
22#include <mutex>
23#include <map>
24#include <utility>
25
26#include "NvInfer.h"
27#include "NvInferRuntime.h"
28
31
32namespace apollo {
33namespace prediction {
34
36 public:
42
50
/**
 * @brief Parse the model description class and load the model.
 *
 * @return bool status of initialization. NOTE(review): presumably true on
 *         success — confirm against the implementation.
 */
57 virtual bool Init();
58
/**
 * @brief Perform network inference.
 *
 * @param input_buffer pointers to the input data. NOTE(review): presumably
 *        one entry per model input (see the input tensor names below) —
 *        confirm whether these are host or device pointers.
 * @param input_size NOTE(review): presumably the number of entries in
 *        input_buffer — confirm.
 * @param output_buffer pointer to the vector that receives the output
 *        buffer pointers (non-const, so the call may write to it).
 * @param output_size NOTE(review): presumably the number of outputs —
 *        confirm.
 * @return bool status of the inference. NOTE(review): presumably true on
 *         success — confirm.
 */
68 virtual bool Inference(
69 const std::vector<void*>& input_buffer,
70 unsigned int input_size,
71 std::vector<void*>* output_buffer,
72 unsigned int output_size);
73
/**
 * @brief Load the model.
 *
 * NOTE(review): the model source (ONNX parse vs. serialized engine) is not
 * visible in this header — confirm against the implementation.
 *
 * @return bool load status; presumably true on success.
 */
79 virtual bool LoadModel();
80
/**
 * @brief Free all requested memory, whether on the GPU or the CPU.
 *
 * The misspelled name "Destory" is part of the public interface and is kept
 * as-is so existing callers continue to compile.
 */
86 virtual void Destory();
87
88 private:
89 nvinfer1::ICudaEngine* engine_ = nullptr;
90 nvinfer1::IRuntime* runtime_ = nullptr;
91 nvinfer1::IExecutionContext* context_ = nullptr;
92 std::vector<void*> buffers_;
93 cudaStream_t stream_;
94
95 int max_agent_num = 50;
96
97 // input output name based on onnx model
98 const char* INPUT_MULTI_OBSTACLE_POS_NAME = "multi_obstacle_pos";
99 const char* INPUT_MULTI_OBSTACLE_POS_STEP_NAME = "multi_obstacle_pos_step";
100 const char* INPUT_VECTOR_DATA_NAME = "vector_data";
101 const char* INPUT_BOOL_VECTOR_MASK_NAME = "bool_vector_mask";
102 const char* INPUT_BOOL_POLYLINE_MASK_NAME = "bool_polyline_mask";
103 const char* INPUT_RAND_MASK_NAME = "rand_mask";
104 const char* INPUT_POLYLINE_ID_NAME = "polyline_id";
105 const char* INPUT_OBS_POSITION_NAME = "obs_position";
106 const char* OUTPUT_NAME = "predict";
107 // input output index
108 int input_multi_obstacle_pos_index_;
109 int input_multi_obstacle_pos_step_index_;
110 int input_vector_data_index_;
111 int input_bool_vector_mask_index_;
112 int input_bool_polyline_mask_index_;
113 int input_rand_mask_index_;
114 int input_polyline_id_index_;
115 int input_obs_position_index_;
116 int output_index_;
117 std::mutex mtx;
118
119 class RTLogger : public nvinfer1::ILogger {
126 void log(Severity severity, const char* msg) noexcept override {
127 switch (severity) {
128 case Severity::kINTERNAL_ERROR:
129 case Severity::kERROR:
130 AERROR << msg;
131 break;
132 case Severity::kWARNING:
133 AWARN << msg;
134 break;
135 case Severity::kINFO:
136 case Severity::kVERBOSE:
137 ADEBUG << msg;
138 break;
139 default:
140 break;
141 }
142 }
143 } rt_gLogger;
144};
147
148} // namespace prediction
149} // namespace apollo
virtual bool Inference(const std::vector< void * > &input_buffer, unsigned int input_size, std::vector< void * > *output_buffer, unsigned int output_size)
Perform network inference.
virtual void Destory()
Free all requested memory, whether on the GPU or the CPU.
MultiAgentVehicleTensorrt()
Construct a new Multi Agent Vehicle TensorRT object.
virtual bool Init()
Parse the model description class and load the model.
~MultiAgentVehicleTensorrt()
Destroy the Multi Agent Vehicle TensorRT object.
#define CYBER_PLUGIN_MANAGER_REGISTER_PLUGIN(name, base)
#define ADEBUG
Definition log.h:41
#define AERROR
Definition log.h:44
#define AWARN
Definition log.h:43
Class registration implementation.
Definition arena_queue.h:37