// Apollo 11.0 — 自动驾驶开放平台 (open autonomous driving platform)
// inference_factory.cc (captured from the file's documentation page)
/******************************************************************************
 * Copyright 2018 The Apollo Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *****************************************************************************/
// NOTE(review): the original #include directives (the factory's own header and
// the per-backend headers such as RTNet/MINet/TorchNet/PaddleNet) were dropped
// in this capture — restore them before compiling.
// Platform dispatch for the TensorRT-style backend: RTNET/RTNET8 expand to the
// constructor call of the runtime network implementation for the GPU vendor
// selected at build time via GPU_PLATFORM.
#if GPU_PLATFORM == NVIDIA
// NVIDIA: both macros build an RTNet; RTNET8 additionally passes model_root
// (used for the int8-quantized variant).
#define RTNET RTNet(proto_file, weight_file, outputs, inputs)
#define RTNET8 RTNet(proto_file, weight_file, outputs, inputs, model_root)
#elif GPU_PLATFORM == AMD
// AMD: MIGraphX-backed network.
#define RTNET MINet(proto_file, weight_file, outputs, inputs)
// TODO(B1tway) Add quantization int8 support for RTNetInt8.
// RTNetInt8 on MIGraphX currently works with fp32.
#define RTNET8 RTNET
#endif

namespace apollo {
namespace perception {
namespace inference {

39Inference *CreateInferenceByName(const std::string &frame_work,
40 const std::string &proto_file,
41 const std::string &weight_file,
42 const std::vector<std::string> &outputs,
43 const std::vector<std::string> &inputs,
44 const std::string &model_root) {
45 if (frame_work == "RTNet") {
46 return new RTNET;
47 } else if (frame_work == "RTNetInt8") {
48 return new RTNET8;
49 } else if (frame_work == "TorchNet") {
50 // PyTorch just have model file, we use proto_file as model file
51 return new TorchNet(proto_file, outputs, inputs);
52 } else if (frame_work == "Obstacle") {
53 return new ObstacleDetector(proto_file, weight_file, outputs, inputs);
54 } else if (frame_work == "Onnx") {
55 return new SingleBatchInference(proto_file, outputs, inputs);
56 } else if (frame_work == "PaddleNet") {
57 return new PaddleNet(proto_file, weight_file, outputs, inputs);
58 }
59 return nullptr;
60}
62Inference *CreateInferenceByName(const common::Framework &frame_work,
63 const std::string &proto_file,
64 const std::string &weight_file,
65 const std::vector<std::string> &outputs,
66 const std::vector<std::string> &inputs,
67 const std::string &model_root) {
68 switch (frame_work) {
69 case common::TensorRT:
70 if (model_root.empty()) {
71 return new RTNET;
72 }
73 return new RTNET8;
74 case common::PyTorch:
75 return new TorchNet(proto_file, outputs, inputs);
76 case common::PaddlePaddle:
77 return new PaddleNet(proto_file, weight_file, outputs, inputs);
78 case common::Obstacle:
79 return new ObstacleDetector(proto_file, weight_file, outputs, inputs);
80 case common::Onnx:
81 return new SingleBatchInference(proto_file, outputs, inputs);
82 default:
83 break;
84 }
85 return nullptr;
86}

}  // namespace inference
}  // namespace perception
}  // namespace apollo
// --- Doxygen page residue below (symbol tooltips), not part of the source ---
// #define RTNET / #define RTNET8 — see the platform-dispatch macros above.
// Inference *CreateInferenceByName(const std::string &frame_work, ...) —
//   declared in the factory's header.
// "class register implement" — Definition arena_queue.h:37 (unrelated tooltip).