30 std::ifstream engine_file(
model_path_, std::ios::binary);
31 engine_file.seekg(0, std::ifstream::end);
32 int64_t engine_size = engine_file.tellg();
34 engine_file.seekg(0, std::ifstream::beg);
36 void* engine_blob = malloc(engine_size);
37 engine_file.read(
reinterpret_cast<char*
>(engine_blob), engine_size);
39 runtime_ = nvinfer1::createInferRuntime(rt_gLogger);
41 AERROR <<
"create runtime failed";
45 engine_ = runtime_->deserializeCudaEngine(engine_blob, engine_size);
47 AERROR <<
"create engine failed";
51 context_ = engine_->createExecutionContext();
53 AERROR <<
"create context failed";
57 ACHECK(engine_->getNbBindings() == 9);
59 input_multi_obstacle_pos_index_ =
60 engine_->getBindingIndex(INPUT_MULTI_OBSTACLE_POS_NAME);
61 input_multi_obstacle_pos_step_index_ =
62 engine_->getBindingIndex(INPUT_MULTI_OBSTACLE_POS_STEP_NAME);
63 input_vector_data_index_ =
64 engine_->getBindingIndex(INPUT_VECTOR_DATA_NAME);
65 input_bool_vector_mask_index_ =
66 engine_->getBindingIndex(INPUT_BOOL_VECTOR_MASK_NAME);
67 input_bool_polyline_mask_index_ =
68 engine_->getBindingIndex(INPUT_BOOL_POLYLINE_MASK_NAME);
69 input_rand_mask_index_ =
70 engine_->getBindingIndex(INPUT_RAND_MASK_NAME);
71 input_polyline_id_index_ =
72 engine_->getBindingIndex(INPUT_POLYLINE_ID_NAME);
73 input_obs_position_index_ =
74 engine_->getBindingIndex(INPUT_OBS_POSITION_NAME);
76 engine_->getBindingIndex(OUTPUT_NAME);
78 buffers_ = std::vector<void*>{
nullptr,
nullptr,
nullptr,
nullptr,
79 nullptr,
nullptr,
nullptr,
nullptr,
nullptr};
82 cudaMalloc(&buffers_[input_multi_obstacle_pos_index_],
83 1 * max_agent_num * 20 * 2 *
sizeof(
float));
84 cudaMalloc(&buffers_[input_multi_obstacle_pos_step_index_],
85 1 * max_agent_num * 20 * 2 *
sizeof(
float));
86 cudaMalloc(&buffers_[input_vector_data_index_],
87 1 * 450 * 50 * 9 *
sizeof(
float));
88 cudaMalloc(&buffers_[input_bool_vector_mask_index_],
89 1 * 450 * 50 *
sizeof(
bool));
90 cudaMalloc(&buffers_[input_bool_polyline_mask_index_],
91 1 * 450 *
sizeof(
bool));
92 cudaMalloc(&buffers_[input_rand_mask_index_],
93 1 * 450 *
sizeof(
bool));
94 cudaMalloc(&buffers_[input_polyline_id_index_],
95 1 * 450 * 2 *
sizeof(
float));
96 cudaMalloc(&buffers_[input_obs_position_index_],
97 1 * max_agent_num * 3 *
sizeof(
float));
98 cudaMalloc(&buffers_[output_index_],
99 1 * max_agent_num * 30 * 2 *
sizeof(
float));
101 cudaStreamCreate(&stream_);
107 const std::vector<void*>& input_buffer,
108 unsigned int input_size,
109 std::vector<void*>* output_buffer,
110 unsigned int output_size) {
111 ACHECK(input_size == input_buffer.size() && input_size == 8);
112 ACHECK(output_size == output_buffer->size() && output_size == 1);
119 std::lock_guard<std::mutex> lck(mtx);
122 buffers_[input_multi_obstacle_pos_index_],
124 1 * max_agent_num * 20 * 2 *
sizeof(
float),
125 cudaMemcpyHostToDevice,
128 buffers_[input_multi_obstacle_pos_step_index_],
130 1 * max_agent_num * 20 * 2 *
sizeof(
float),
131 cudaMemcpyHostToDevice,
134 buffers_[input_vector_data_index_],
136 1 * 450 * 50 * 9 *
sizeof(
float),
137 cudaMemcpyHostToDevice,
140 buffers_[input_bool_vector_mask_index_],
142 1 * 450 * 50 *
sizeof(
bool),
143 cudaMemcpyHostToDevice,
146 buffers_[input_bool_polyline_mask_index_],
148 1 * 450 *
sizeof(
bool),
149 cudaMemcpyHostToDevice,
152 buffers_[input_rand_mask_index_],
154 1 * 450 *
sizeof(
bool),
155 cudaMemcpyHostToDevice,
158 buffers_[input_polyline_id_index_],
160 1 * 450 * 2 *
sizeof(
float),
161 cudaMemcpyHostToDevice,
164 buffers_[input_obs_position_index_],
166 1 * max_agent_num * 3 *
sizeof(
float),
167 cudaMemcpyHostToDevice,
169 if (!context_->enqueueV2(buffers_.data(), stream_,
nullptr)) {
173 (*output_buffer)[0], buffers_[output_index_],
174 1 * max_agent_num * 30 * 2 *
sizeof(
float),
175 cudaMemcpyDeviceToHost,
177 cudaStreamSynchronize(stream_);
182 cudaStreamDestroy(stream_);
183 cudaFree(buffers_[input_multi_obstacle_pos_index_]);
184 cudaFree(buffers_[input_multi_obstacle_pos_step_index_]);
185 cudaFree(buffers_[input_vector_data_index_]);
186 cudaFree(buffers_[input_bool_vector_mask_index_]);
187 cudaFree(buffers_[input_bool_polyline_mask_index_]);
188 cudaFree(buffers_[input_rand_mask_index_]);
189 cudaFree(buffers_[input_polyline_id_index_]);
190 cudaFree(buffers_[input_obs_position_index_]);
191 cudaFree(buffers_[output_index_]);