Apollo 11.0
自动驾驶开放平台
semantic_lstm_vehicle_tensorrt.h
浏览该文件的文档.
1/******************************************************************************
2 * Copyright 2023 The Apollo Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *****************************************************************************/
16
17#pragma once
18
19#include <string>
20#include <vector>
21#include <memory>
22#include <mutex>
23#include <map>
24#include <utility>
25
26#include "NvInfer.h"
27#include "NvInferRuntime.h"
28
31
32namespace apollo {
33namespace prediction {
34
36public:
41
// Parses the model description class and loads the model.
// Returns a bool, presumably true on success (TODO confirm against the .cc).
48 virtual bool Init();
49
// Performs network inference.
//   input_buffer  - pointers to the input data; presumably ordered to match the
//                   img_tensor / obstacle_pos / obstacle_pos_step inputs
//                   (TODO confirm against the .cc).
//   input_size    - presumably the number of entries in input_buffer (confirm).
//   output_buffer - vector of pointers the results are delivered through;
//                   presumably caller-allocated (TODO confirm).
//   output_size   - presumably the number of entries in output_buffer (confirm).
// Returns a bool, presumably true on success.
59 virtual bool Inference(
60 const std::vector<void*>& input_buffer,
61 unsigned int input_size,
62 std::vector<void*>* output_buffer,
63 unsigned int output_size);
64
// Loads the model. Given the engine_/runtime_/context_ members this presumably
// builds or deserializes the TensorRT engine and creates the execution context
// (TODO confirm in the .cc). Returns a bool, presumably true on success.
70 virtual bool LoadModel();
71
// Frees all requested memory, whether on the GPU or the CPU.
// NOTE(review): "Destory" is a misspelling of "Destroy". It is kept as-is
// because renaming a virtual interface method would break overriders/callers.
77 virtual void Destory();
78
79private:
80 nvinfer1::ICudaEngine* engine_ = nullptr;
81 nvinfer1::IRuntime* runtime_ = nullptr;
82 nvinfer1::IExecutionContext* context_ = nullptr;
83 std::vector<void*> buffers_;
84 cudaStream_t stream_;
85
86 // input output name based on onnx model
87 const char* INPUT_IMAGE_NAME = "img_tensor";
88 const char* INPUT_OBSTACLE_POS_NAME = "obstacle_pos";
89 const char* INPUT_OBSTACLE_POS_STEP_NAME = "obstacle_pos_step";
90 const char* OUTPUT_NAME = "predict";
91 // input output index
92 int input_image_index_;
93 int input_obstacle_pos_index_;
94 int input_obstacle_pos_step_index_;
95 int output_index_;
96 std::mutex mtx;
97
98 class RTLogger : public nvinfer1::ILogger {
105 void log(Severity severity, const char* msg) noexcept override {
106 switch (severity) {
107 case Severity::kINTERNAL_ERROR:
108 case Severity::kERROR:
109 AERROR << msg;
110 break;
111 case Severity::kWARNING:
112 AWARN << msg;
113 break;
114 case Severity::kINFO:
115 case Severity::kVERBOSE:
116 ADEBUG << msg;
117 break;
118 default:
119 break;
120 }
121 }
122 } rt_gLogger;
123};
125
126} // namespace prediction
127} // namespace apollo
virtual void Destory()
Free all requested memory, whether on the GPU or the CPU.
virtual bool Init()
Parse the model description class and load the model.
virtual bool Inference(const std::vector< void * > &input_buffer, unsigned int input_size, std::vector< void * > *output_buffer, unsigned int output_size)
Perform network inference.
#define CYBER_PLUGIN_MANAGER_REGISTER_PLUGIN(name, base)
#define ADEBUG
Definition log.h:41
#define AERROR
Definition log.h:44
#define AWARN
Definition log.h:43
Class registration implementation.
Definition arena_queue.h:37