Apollo 10.0
自动驾驶开放平台
model_manager.cc
浏览该文件的文档.
1/******************************************************************************
2 * Copyright 2023 The Apollo Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *****************************************************************************/
16
18
19#include <algorithm>
20#include <sstream>
21
22#include "cyber/common/file.h"
25
26namespace apollo {
27namespace prediction {
28
30 PredictionConf prediciton_conf;
32 FLAGS_prediction_conf_file, &prediciton_conf)) {
33 AERROR << "Unable to load prediction conf file: "
34 << FLAGS_prediction_conf_file;
35 ACHECK(false);
36 }
37 evaluator_model_conf_ = prediciton_conf.evaluator_model_conf();
38 models_ = std::map<std::string, std::vector<std::shared_ptr<ModelBase>>>{};
39}
40
42 static const std::string class_namespace = "apollo::prediction::";
43 std::map<std::string, std::vector<Model>> m;
44 for (auto& model_conf : evaluator_model_conf_.model()) {
45 auto key = to_string(model_conf);
46 if (model_conf.priority() == 0) {
47 // priority = 0 means this plugin only applicable to arm arch
48 #ifndef __aarch64__
49 continue;
50 #endif
51 }
52 if (m.find(key) != m.end()) {
53 m[key].push_back(model_conf);
54 } else {
55 m[key] = std::vector<Model>{model_conf};
56 }
57 }
58
59 // sort the model based on priority
60 for (auto& p : m) {
61 sort(m[p.first].begin(), m[p.first].end(),
62 [](Model& a, Model& b){return a.priority() < b.priority();});
63 }
64
65 for (auto& p : m) {
66 auto key = p.first;
67 auto& model_conf_vec = p.second;
68 models_[key] = std::vector<std::shared_ptr<ModelBase>>{};
69 for (auto model_conf : model_conf_vec) {
70 std::string class_name = model_conf.type();
71 if (class_name.find("::") == std::string::npos) {
72 class_name = class_namespace + class_name;
73 }
74 auto model_ptr =
76 ->CreateInstance<ModelBase>(class_name);
77 models_[key].push_back(model_ptr);
78 }
79 }
80
81 return true;
82}
83
84std::string ModelManager::to_string(const Model& model_config) {
85 std::ostringstream key;
86 key << model_config.evaluator_type() << model_config.obstacle_type();
87 key << model_config.backend();
88 return key.str();
89}
90
91std::string ModelManager::to_string(const Model::Backend& backend,
92 const ObstacleConf::EvaluatorType& evaluator_type,
94 obstacle_type) {
95 std::ostringstream key;
96 key << evaluator_type << obstacle_type << backend;
97 return key.str();
98}
99
100std::shared_ptr<ModelBase> ModelManager::SelectModel(
101 const Model::Backend& backend,
102 const ObstacleConf::EvaluatorType& evaluator_type,
104 obstacle_type) {
105 auto key = to_string(backend, evaluator_type, obstacle_type);
106 if (models_.find(key) == models_.end()) {
107 AERROR << "Can't find model with attribute: evaluator_type("
108 << evaluator_type << ") " << "obstacle_type(" << obstacle_type
109 << ") " << "backend(" << backend << ")";
110 ACHECK(false);
111 }
112 return models_[key][0];
113}
114
115} // namespace prediction
116} // namespace apollo
static PluginManager * Instance()
Get the singleton instance of PluginManager.
std::shared_ptr< Base > CreateInstance(const std::string &derived_class)
Create a plugin instance of the derived class, returned as a pointer to Base.
std::shared_ptr< ModelBase > SelectModel(const Model::Backend &backend, const ObstacleConf::EvaluatorType &evaluator_type, const apollo::perception::PerceptionObstacle::Type &obstacle_type)
select the best model
bool Init()
Initialize the model manager and load all defined plugins.
#define ACHECK(cond)
Definition log.h:80
#define AERROR
Definition log.h:44
bool GetProtoFromFile(const std::string &file_name, google::protobuf::Message *message)
Parses the content of the file specified by the file_name as a representation of protobufs,...
Definition file.cc:132
Class registration implementation.
Definition arena_queue.h:37
optional ObstacleConf::EvaluatorType evaluator_type
optional apollo::perception::PerceptionObstacle::Type obstacle_type
optional EvaluatorModelConf evaluator_model_conf
Warm-up function for Torch models, to avoid slow inference on the first several runs.