69 const std::vector<void*>& input_buffer,
70 unsigned int input_size,
71 std::vector<void*>* output_buffer,
72 unsigned int output_size);
89 nvinfer1::ICudaEngine* engine_ =
nullptr;
90 nvinfer1::IRuntime* runtime_ =
nullptr;
91 nvinfer1::IExecutionContext* context_ =
nullptr;
92 std::vector<void*> buffers_;
95 int max_agent_num = 50;
// Tensor names for the network's eight inputs and single output. Presumably
// these are looked up by name in the TensorRT engine, so they must match the
// names used when the engine was built — TODO confirm.
// NOTE(review): the values never change, so class-level `static constexpr`
// would avoid a per-instance copy once the project is confirmed to build as
// C++17 (implicitly inline static members).
const char* INPUT_MULTI_OBSTACLE_POS_NAME{"multi_obstacle_pos"};
const char* INPUT_MULTI_OBSTACLE_POS_STEP_NAME{"multi_obstacle_pos_step"};
const char* INPUT_VECTOR_DATA_NAME{"vector_data"};
const char* INPUT_BOOL_VECTOR_MASK_NAME{"bool_vector_mask"};
const char* INPUT_BOOL_POLYLINE_MASK_NAME{"bool_polyline_mask"};
const char* INPUT_RAND_MASK_NAME{"rand_mask"};
const char* INPUT_POLYLINE_ID_NAME{"polyline_id"};
const char* INPUT_OBS_POSITION_NAME{"obs_position"};
const char* OUTPUT_NAME{"predict"};
// Input binding indices, presumably resolved from the corresponding
// INPUT_*_NAME strings — TODO confirm where they are assigned.
// Each starts at -1 ("not resolved yet"). Without an initializer these ints
// would be indeterminate, and reading one before assignment is undefined
// behaviour. -1 is also the value TensorRT's ICudaEngine::getBindingIndex()
// returns for an unknown name, so a single `< 0` check covers both cases.
int input_multi_obstacle_pos_index_ = -1;
int input_multi_obstacle_pos_step_index_ = -1;
int input_vector_data_index_ = -1;
int input_bool_vector_mask_index_ = -1;
int input_bool_polyline_mask_index_ = -1;
int input_rand_mask_index_ = -1;
int input_polyline_id_index_ = -1;
int input_obs_position_index_ = -1;
119 class RTLogger :
public nvinfer1::ILogger {
126 void log(Severity severity,
const char* msg)
noexcept override {
128 case Severity::kINTERNAL_ERROR:
129 case Severity::kERROR:
132 case Severity::kWARNING:
135 case Severity::kINFO:
136 case Severity::kVERBOSE: