Apollo 11.0
自动驾驶开放平台
semantic_lstm_vehicle_tensorrt.cc
浏览该文件的文档.
1/******************************************************************************
2 * Copyright 2023 The Apollo Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *****************************************************************************/
16
18
19#include <cxxabi.h>
20
21#include <fstream>
22
23#include <cuda_runtime_api.h>
24
25#include "modules/prediction/proto/prediction_conf.pb.h"
26
27#include "cyber/common/file.h"
28
29namespace apollo {
30namespace prediction {
31
33 std::ifstream engine_file(model_path_, std::ios::binary);
34 engine_file.seekg(0, std::ifstream::end);
35 int64_t engine_size = engine_file.tellg();
36 ACHECK(engine_size > 0);
37 engine_file.seekg(0, std::ifstream::beg);
38
39 void* engine_blob = malloc(engine_size);
40 engine_file.read(reinterpret_cast<char*>(engine_blob), engine_size);
41
42 runtime_ = nvinfer1::createInferRuntime(rt_gLogger);
43 if (!runtime_) {
44 AERROR << "create runtime failed";
45 return false;
46 }
47
48 engine_ = runtime_->deserializeCudaEngine(engine_blob, engine_size);
49 if (!engine_) {
50 AERROR << "create engine failed";
51 return false;
52 }
53
54 context_ = engine_->createExecutionContext();
55 if (!context_) {
56 AERROR << "create context failed";
57 return false;
58 }
59 // check model input output
60 ACHECK(engine_->getNbBindings() == 4);
61
62 input_image_index_ = engine_->getBindingIndex(INPUT_IMAGE_NAME);
63 input_obstacle_pos_index_ = engine_->getBindingIndex(INPUT_OBSTACLE_POS_NAME);
64 input_obstacle_pos_step_index_ = engine_->getBindingIndex(INPUT_OBSTACLE_POS_STEP_NAME);
65 output_index_ = engine_->getBindingIndex(OUTPUT_NAME);
66
67 buffers_ = std::vector<void*>{nullptr, nullptr, nullptr, nullptr};
68
69 // size based on the shape of model input and output
70 cudaMalloc(&buffers_[input_image_index_], 1 * 3 * 224 * 224 * sizeof(float));
71 cudaMalloc(&buffers_[input_obstacle_pos_index_], 1 * 20 * 2 * sizeof(float));
72 cudaMalloc(&buffers_[input_obstacle_pos_step_index_], 1 * 20 * 2 * sizeof(float));
73 cudaMalloc(&buffers_[output_index_], 1 * 30 * 2 * sizeof(float));
74
75 cudaStreamCreate(&stream_);
76
77 return true;
78}
79
81 const std::vector<void*>& input_buffer,
82 unsigned int input_size,
83 std::vector<void*>* output_buffer,
84 unsigned int output_size) {
85 ACHECK(input_size == input_buffer.size() && input_size == 3);
86 ACHECK(output_size == output_buffer->size() && output_size == 1);
87
88 if (init_ == 0) {
89 Init();
90 }
91
92 // ensure thread safe for inference
93 std::lock_guard<std::mutex> lck(mtx);
94
95 cudaMemcpyAsync(
96 buffers_[input_image_index_],
97 input_buffer[0],
98 1 * 3 * 224 * 224 * sizeof(float),
99 cudaMemcpyHostToDevice,
100 stream_);
101 cudaMemcpyAsync(
102 buffers_[input_obstacle_pos_index_],
103 input_buffer[1],
104 1 * 20 * 2 * sizeof(float),
105 cudaMemcpyHostToDevice,
106 stream_);
107 cudaMemcpyAsync(
108 buffers_[input_obstacle_pos_step_index_],
109 input_buffer[2],
110 1 * 20 * 2 * sizeof(float),
111 cudaMemcpyHostToDevice,
112 stream_);
113 if (!context_->enqueueV2(buffers_.data(), stream_, nullptr)) {
114 return false;
115 }
116 cudaMemcpyAsync(
117 (*output_buffer)[0], buffers_[output_index_], 1 * 30 * 2 * sizeof(float), cudaMemcpyDeviceToHost, stream_);
118 cudaStreamSynchronize(stream_);
119 return true;
120}
121
123 cudaStreamDestroy(stream_);
124 cudaFree(buffers_[input_image_index_]);
125 cudaFree(buffers_[input_obstacle_pos_index_]);
126 cudaFree(buffers_[input_obstacle_pos_step_index_]);
127 cudaFree(buffers_[output_index_]);
128}
129
// Init: one-time setup. Loads the model configuration from this plugin's
// "conf/default_conf.pb.txt", stores its model_path in model_path_, and then
// calls LoadModel(). Returns true immediately when init_ is already non-zero.
131 // hook: Apollo License Verification: v_apollo_park
132 ModelConf model_config;
133 int status = 0;
134
// Initialization already done (or already attempted): nothing to do.
135 if (init_ != 0) {
136 return true;
137 }
138
// The demangled dynamic type name of *this is combined with
// "conf/default_conf.pb.txt" to resolve default_config_path (presumably via
// the plugin manager's config-path lookup — confirm).
// NOTE(review): abi::__cxa_demangle returns a malloc()-allocated buffer that
// the caller must free(); it is never freed here, so it leaks. On failure it
// returns nullptr (status != 0), and building a std::string from nullptr is
// undefined behavior. Consider checking status and freeing the buffer.
139 std::string class_name = abi::__cxa_demangle(typeid(*this).name(), 0, 0, &status);
140 std::string default_config_path
142 class_name, "conf/default_conf.pb.txt");
143
144 if (!cyber::common::GetProtoFromFile(default_config_path, &model_config)) {
145 AERROR << "Unable to load model conf file: " << default_config_path;
146 return false;
147 }
148 model_path_ = model_config.model_path();
// NOTE(review): init_ is set before LoadModel() runs. If LoadModel() fails,
// later calls take the early return above and report success anyway.
149 init_ = 1;
150
151 return LoadModel();
152}
153
154} // namespace prediction
155} // namespace apollo
std::string GetPluginConfPath(const std::string &class_name, const std::string &conf_name)
Gets the location of the plugin configuration file.
static PluginManager * Instance()
Gets the singleton instance of PluginManager.
virtual void Destory()
Frees all requested memory, on GPU or CPU.
virtual bool Init()
Parses the model description class and loads the model.
virtual bool Inference(const std::vector< void * > &input_buffer, unsigned int input_size, std::vector< void * > *output_buffer, unsigned int output_size)
Performs network inference.
#define ACHECK(cond)
Definition log.h:80
#define AERROR
Definition log.h:44
bool GetProtoFromFile(const std::string &file_name, google::protobuf::Message *message)
Parses the content of the file specified by the file_name as a representation of protobufs,...
Definition file.cc:132
class register implement
Definition arena_queue.h:37