67 std::vector<int>& prediction_obs_ids,
68 torch::Tensor* ptr_multi_obstacle_pos,
69 torch::Tensor* ptr_multi_obstacle_pos_step,
70 torch::Tensor* ptr_vector_mask,
71 torch::Tensor* ptr_all_obs_data,
72 torch::Tensor* ptr_all_obs_p_id,
73 torch::Tensor* ptr_multi_obstacle_position);
87 torch::Tensor* ptr_map_data,
88 torch::Tensor* ptr_all_map_p_id,
89 torch::Tensor* ptr_vector_mask);
123 std::vector<int>& prediction_obs_ids,
124 std::vector<TrajectoryPoint>* trajectory_points,
125 std::vector<std::vector<std::pair<double, double>>>* multi_obs_pos,
126 std::vector<std::vector<double>>* multi_obs_position,
127 std::vector<std::pair<double, double>>* all_obs_length,
128 std::vector<std::vector<std::pair<double, double>>>* all_obs_pos_history,
129 std::vector<std::pair<double, double>>* adc_traj_curr_pos,
130 torch::Tensor* vector_mask);
135 std::string
GetName()
override {
return "MULTI_AGENT_EVALUATOR"; }
144 int max_agent_num = 50;
146 int vector_obs_num = 0;
147 bool with_planning_traj =
true;
149 at::Tensor torch_default_output_tensor_;
bool VectornetProcessObstaclePosition(ObstaclesContainer *obstacles_container, const ADCTrajectoryContainer *adc_trajectory_container, std::vector< int > &prediction_obs_ids, torch::Tensor *ptr_multi_obstacle_pos, torch::Tensor *ptr_multi_obstacle_pos_step, torch::Tensor *ptr_vector_mask, torch::Tensor *ptr_all_obs_data, torch::Tensor *ptr_all_obs_p_id, torch::Tensor *ptr_multi_obstacle_position)
Process obstacle positions into vector (tensor) form
bool VectornetProcessMapData(FeatureVector *map_feature, PidVector *map_p_id, const int obs_num, torch::Tensor *ptr_map_data, torch::Tensor *ptr_all_map_p_id, torch::Tensor *ptr_vector_mask)
Process map data into vector (tensor) form
bool ExtractObstaclesHistory(ObstaclesContainer *obstacles_container, std::vector< int > &prediction_obs_ids, std::vector< TrajectoryPoint > *trajectory_points, std::vector< std::vector< std::pair< double, double > > > *multi_obs_pos, std::vector< std::vector< double > > *multi_obs_position, std::vector< std::pair< double, double > > *all_obs_length, std::vector< std::vector< std::pair< double, double > > > *all_obs_pos_history, std::vector< std::pair< double, double > > *adc_traj_curr_pos, torch::Tensor *vector_mask)
Extract the history of all obstacles