Apollo 11.0
自动驾驶开放平台
multi_agent_evaluator.h
浏览该文件的文档.
1/******************************************************************************
2 * Copyright 2021 The Apollo Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *****************************************************************************/
16
17#pragma once
18
19#include <string>
20#include <utility>
21#include <vector>
22
23#include "torch/extension.h"
24#include "torch/script.h"
29
30namespace apollo {
31namespace prediction {
32
34
36 public:
41
/**
 * @brief Destructor. Declared virtual and defaulted.
 */
45 virtual ~MultiAgentEvaluator() = default;
46
/**
 * @brief Clear obstacle feature map.
 */
50 void Clear();
51
// NOTE(review): the line that opens this declaration is not present in this
// listing. The symbol summary further down the page shows the same parameter
// list on `bool VectornetProcessObstaclePosition(...)`, summarized there as
// "Process obstacle position to vector".
// The non-const torch::Tensor* arguments are presumably outputs filled by the
// call - TODO confirm against the implementation.
65 ObstaclesContainer* obstacles_container,
66 const ADCTrajectoryContainer* adc_trajectory_container,
67 std::vector<int>& prediction_obs_ids,
68 torch::Tensor* ptr_multi_obstacle_pos,
69 torch::Tensor* ptr_multi_obstacle_pos_step,
70 torch::Tensor* ptr_vector_mask,
71 torch::Tensor* ptr_all_obs_data,
72 torch::Tensor* ptr_all_obs_p_id,
73 torch::Tensor* ptr_multi_obstacle_position);
74
/**
 * @brief Process map data to vector.
 * @param map_feature pointer to a FeatureVector
 *        (std::vector<std::vector<std::vector<double>>>)
 * @param map_p_id pointer to a PidVector (std::vector<std::vector<double>>)
 * @param obs_num integer passed by value; presumably the number of
 *        obstacles - TODO confirm
 * @param ptr_map_data tensor pointer; presumably an output - TODO confirm
 * @param ptr_all_map_p_id tensor pointer; presumably an output - TODO confirm
 * @param ptr_vector_mask tensor pointer; presumably an output - TODO confirm
 * @return bool; presumably true on success - the definition is not visible
 *         here, so verify against the implementation
 */
84 bool VectornetProcessMapData(FeatureVector *map_feature,
85 PidVector *map_p_id,
86 const int obs_num,
87 torch::Tensor* ptr_map_data,
88 torch::Tensor* ptr_all_map_p_id,
89 torch::Tensor* ptr_vector_mask);
90
/**
 * @brief Override Evaluate.
 * @param obstacle pointer to the prediction obstacle
 * @param obstacles_container pointer to the obstacles container
 * @return bool; presumably true on success - the definition is not visible
 *         here, so verify against the implementation
 */
96 bool Evaluate(Obstacle* obstacle, ObstaclesContainer* obstacles_container) override;
97
/**
 * @brief Override of Evaluate that additionally receives the ADC trajectory
 *        container.
 * @param adc_trajectory_container read-only pointer to the ADC trajectory
 *        container
 * @param obstacle pointer to the prediction obstacle
 * @param obstacles_container pointer to the obstacles container
 * @return bool; presumably true on success - the definition is not visible
 *         here, so verify against the implementation
 */
104 bool Evaluate(const ADCTrajectoryContainer* adc_trajectory_container,
105 Obstacle* obstacle,
106 ObstaclesContainer* obstacles_container) override;
107
// NOTE(review): the line that opens this declaration is not present in this
// listing. The symbol summary further down the page shows the same parameter
// list on `bool ExtractObstaclesHistory(...)`, summarized there as
// "Extract all obstacles history".
// The non-const pointer arguments are presumably outputs filled by the call -
// TODO confirm against the implementation.
122 ObstaclesContainer* obstacles_container,
123 std::vector<int>& prediction_obs_ids,
124 std::vector<TrajectoryPoint>* trajectory_points,
125 std::vector<std::vector<std::pair<double, double>>>* multi_obs_pos,
126 std::vector<std::vector<double>>* multi_obs_position,
127 std::vector<std::pair<double, double>>* all_obs_length,
128 std::vector<std::vector<std::pair<double, double>>>* all_obs_pos_history,
129 std::vector<std::pair<double, double>>* adc_traj_curr_pos,
130 torch::Tensor* vector_mask);
131
/**
 * @brief Get the name of evaluator.
 * @return the fixed string "MULTI_AGENT_EVALUATOR"
 */
135 std::string GetName() override { return "MULTI_AGENT_EVALUATOR"; }
136
137 private:
/**
 * @brief Private helper taking no arguments and returning nothing.
 *        Presumably loads the model used by this evaluator - the definition
 *        is not visible here, so verify against the implementation.
 */
141 void LoadModel();
142
143 private:
// Integer setting, initialized to 50; presumably an upper bound on the number
// of agents handled at once - TODO confirm against the implementation.
144 int max_agent_num = 50;
// Integer counter, initialized to 0; presumably the current obstacle count -
// TODO confirm.
145 int obs_num = 0;
// Integer counter, initialized to 0; exact meaning is not visible here.
146 int vector_obs_num = 0;
// Boolean flag, initialized to true; presumably selects whether the planning
// trajectory is used as an input - TODO confirm.
147 bool with_planning_traj = true;
// ModelManager instance owned by value.
148 ModelManager model_manager_;
// Tensor member; by its name presumably a default output tensor - how it is
// filled and used is not visible here.
149 at::Tensor torch_default_output_tensor_;
// Backend selection for the model (type Model::Backend).
150 Model::Backend device_;
// VectorNet instance owned by value.
151 VectorNet vector_net_;
152};
153
154} // namespace prediction
155} // namespace apollo
ADC trajectory container
std::string GetName() override
Get the name of evaluator.
bool VectornetProcessObstaclePosition(ObstaclesContainer *obstacles_container, const ADCTrajectoryContainer *adc_trajectory_container, std::vector< int > &prediction_obs_ids, torch::Tensor *ptr_multi_obstacle_pos, torch::Tensor *ptr_multi_obstacle_pos_step, torch::Tensor *ptr_vector_mask, torch::Tensor *ptr_all_obs_data, torch::Tensor *ptr_all_obs_p_id, torch::Tensor *ptr_multi_obstacle_position)
Process obstacle position to vector
bool VectornetProcessMapData(FeatureVector *map_feature, PidVector *map_p_id, const int obs_num, torch::Tensor *ptr_map_data, torch::Tensor *ptr_all_map_p_id, torch::Tensor *ptr_vector_mask)
Process map data to vector
bool ExtractObstaclesHistory(ObstaclesContainer *obstacles_container, std::vector< int > &prediction_obs_ids, std::vector< TrajectoryPoint > *trajectory_points, std::vector< std::vector< std::pair< double, double > > > *multi_obs_pos, std::vector< std::vector< double > > *multi_obs_position, std::vector< std::pair< double, double > > *all_obs_length, std::vector< std::vector< std::pair< double, double > > > *all_obs_pos_history, std::vector< std::pair< double, double > > *adc_traj_curr_pos, torch::Tensor *vector_mask)
Extract all obstacles history
bool Evaluate(Obstacle *obstacle, ObstaclesContainer *obstacles_container) override
Override Evaluate
void Clear()
Clear obstacle feature map
virtual ~MultiAgentEvaluator()=default
Destructor
Prediction obstacle.
Definition obstacle.h:52
std::vector< std::vector< double > > PidVector
Definition vector_net.h:34
std::vector< std::vector< std::vector< double > > > FeatureVector
Definition vector_net.h:33
class register implement
Definition arena_queue.h:37
Define the data container base class