Apollo 10.0
自动驾驶开放平台
external_feature_extractor.cc
浏览该文件的文档.
1/******************************************************************************
2 * Copyright 2018 The Apollo Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *****************************************************************************/
17
18#include <map>
19#include <vector>
20
21#include "cyber/common/file.h"
28
29namespace apollo {
30namespace perception {
31namespace camera {
32
34 const FeatureExtractorInitOptions &options) {
35 std::string config_file =
36 GetConfigFile(options.config_path, options.config_file);
37 if (!cyber::common::GetProtoFromFile(config_file, &model_param_)) {
38 AERROR << "Read config failed: " << config_file;
39 return false;
40 }
41
42 AINFO << "Load config Success: " << model_param_.ShortDebugString();
43
44 const auto &model_info = model_param_.info();
45 std::string model_path = GetModelPath(model_info.name());
46 std::string proto_file =
47 GetModelFile(model_path, model_info.proto_file().file());
48 std::string weight_file =
49 GetModelFile(model_path, model_info.weight_file().file());
50
51 // Network input and output names
52 std::vector<std::string> input_names =
53 inference::GetBlobNames(model_info.inputs());
54 std::vector<std::string> output_names =
55 inference::GetBlobNames(model_info.outputs());
56
57 height_ = model_param_.resize().height();
58 width_ = model_param_.resize().width();
59
61 model_info.framework(), proto_file, weight_file, output_names,
62 input_names, model_path));
63 ACHECK(nullptr != net_) << "Failed to init CNNAdapter";
64
65 gpu_id_ = GlobalConfig::Instance()->track_feature_gpu_id;
66 net_->set_gpu_id(gpu_id_);
67
68 std::map<std::string, std::vector<int>> shape_map;
69 inference::AddShape(&shape_map, model_info.inputs());
70 inference::AddShape(&shape_map, model_info.outputs());
71 if (!net_->Init(shape_map)) {
72 AERROR << model_info.name() << "init failed!";
73 return false;
74 }
75
76 net_->Infer();
77
78 InitFeatureExtractor(options);
79 image_.reset(new base::Image8U(height_, width_, base::Color::BGR));
80 return true;
81}
82
83bool ExternalFeatureExtractor::InitFeatureExtractor(
84 const FeatureExtractorInitOptions &options) {
85 FeatureExtractorInitOptions feat_options;
86 feat_options.config_path = options.feature_path;
87 feat_options.config_file = options.feature_file;
88 feat_options.gpu_id = gpu_id_;
89 // feat_blob is the 0th of the output blobs
90 auto feat_blob_name = model_param_.info().outputs(0).name();
91 feat_options.feat_blob = net_->get_blob(feat_blob_name);
92 feat_options.input_height = height_;
93 feat_options.input_width = width_;
94 feature_extractor_.reset(BaseFeatureExtractorRegisterer::GetInstanceByName(
95 "TrackingFeatureExtractor"));
96 feature_extractor_->Init(feat_options);
97 return true;
98}
99
101 CameraTrackingFrame *frame) {
102 int raw_height = frame->data_provider->src_height();
103 int raw_width = frame->data_provider->src_width();
104 auto input_blob = net_->get_blob(model_param_.info().inputs(0).name());
105 DataProvider::ImageOptions image_options;
106 image_options.target_color = base::Color::BGR;
107 auto offset_y_ = static_cast<int>(
108 model_param_.offset_ratio() * static_cast<float>(raw_height) + 0.5f);
109 image_options.crop_roi =
110 base::RectI(0, offset_y_, raw_width, raw_height - offset_y_);
111 image_options.do_crop = true;
112 // Timer timer;
113 frame->data_provider->GetImage(image_options, image_.get());
114 inference::ResizeGPU(*image_, input_blob, raw_width, 0);
115 net_->Infer();
116 FeatureExtractorOptions feat_options;
117 feat_options.normalized = false;
118 feature_extractor_->set_roi(
119 image_options.crop_roi.x, image_options.crop_roi.y,
120 image_options.crop_roi.width, image_options.crop_roi.height);
121 feature_extractor_->Extract(feat_options, frame);
122 AINFO << "Extract Done";
123 return true;
124}
125
127
128} // namespace camera
129} // namespace perception
130} // namespace apollo
#define REGISTER_FEATURE_EXTRACTOR(name)
A wrapper around Blob holders serving as the basic computational unit for images.
Definition image_8u.h:44
bool Extract(const FeatureExtractorOptions &options, CameraTrackingFrame *frame) override
bool Init(const FeatureExtractorInitOptions &init_options) override
#define ACHECK(cond)
Definition log.h:80
#define AERROR
Definition log.h:44
#define AINFO
Definition log.h:42
bool GetProtoFromFile(const std::string &file_name, google::protobuf::Message *message)
Parses the content of the file specified by the file_name as a representation of protobufs,...
Definition file.cc:132
Rect< int > RectI
Definition box.h:159
std::vector< std::string > GetBlobNames(const google::protobuf::RepeatedPtrField< common::ModelBlob > &model_blobs)
Definition model_util.cc:23
void AddShape(std::map< std::string, std::vector< int > > *shape_map, const google::protobuf::RepeatedPtrField< common::ModelBlob > &model_blobs)
Definition model_util.cc:32
Inference * CreateInferenceByName(const std::string &frame_work, const std::string &proto_file, const std::string &weight_file, const std::vector< std::string > &outputs, const std::vector< std::string > &inputs, const std::string &model_root)
bool ResizeGPU(const base::Image8U &src, std::shared_ptr< apollo::perception::base::Blob< float > > dst, int stepwidth, int start_axis)
std::string GetModelFile(const std::string &model_name, const std::string &file_name)
Get the model file path by model path and file name
Definition util.cc:55
std::string GetConfigFile(const std::string &config_path, const std::string &config_file)
Definition util.cc:80
std::string GetModelPath(const std::string &model_name)
Get the model path by model name, search from APOLLO_MODEL_PATH
Definition util.cc:44
class register implement
Definition arena_queue.h:37
std::shared_ptr< camera::DataProvider > data_provider