Apollo 10.0
自动驾驶开放平台
tracking_feat_extractor.cc
浏览该文件的文档.
1/******************************************************************************
2 * Copyright 2023 The Apollo Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the License);
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an AS IS BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *****************************************************************************/
17
18#include "cyber/common/file.h"
20
21namespace apollo {
22namespace perception {
23namespace camera {
24
26 const FeatureExtractorInitOptions &options) {
// Initializes the tracking feature extractor from its config file.
//
// Visible steps:
//  1. Resolve the config path with GetConfigFile(options.config_path,
//     options.config_file) and parse it into `feat_param`.
//  2. Require exactly one extractor entry in the config.
//  3. Allocate the feature blob from the `feat_blob_shape` listed in the
//     config (later read back through feat_blob_).
//  4. Derive input_height_/input_width_, falling back to the feature-map
//     height/width when the options leave them at 0, and check that both
//     dimensions share the same integer scale factor.
//  5. Dispatch on each extractor's type to build the pooling layers.
//
// @param options  Reads config_path, config_file, input_height and
//                 input_width.
// @return false if the config cannot be parsed, does not contain exactly
//         one extractor, or roi_poolings_ is still empty after the
//         extractor loop; true otherwise.
//
// NOTE(review): `feat_param` is not declared in the lines visible here;
// presumably it is a local tracking_feature::FeatureParam (or a member) —
// confirm against the full source / header.
28 std::string config_file =
29 GetConfigFile(options.config_path, options.config_file);
30
31 if (!cyber::common::GetProtoFromFile(config_file, &feat_param)) {
32 AERROR << "Read feature extractor config file failed!";
33 return false;
34 }
// Only a single extractor is accepted, so the extractor loop below
// iterates exactly once.
35 if (feat_param.extractor_size() != 1) {
36 AERROR << "extractor should be 1";
37 return false;
38 }
39
// NOTE(review): feat_blob_name is read but not used in any of the lines
// visible here.
40 const std::string feat_blob_name = feat_param.feat_blob_name();
41 const auto &feat_blob_shape_pb = feat_param.feat_blob_shape();
42
// Copy the repeated shape field into a std::vector<int> and allocate a
// float Blob of that shape. The assignment target is not visible here;
// feat_blob_ is read immediately below, so presumably it is feat_blob_.
43 std::vector<int> feat_blob_shape(feat_blob_shape_pb.begin(),
44 feat_blob_shape_pb.end());
46 std::make_shared<apollo::perception::base::Blob<float>>(feat_blob_shape);
47
48 // setup bottom and top
// Axes 2 and 3 are used as height and width, so feat_blob_shape must have
// at least 4 entries (presumably N, C, H, W — confirm against the config).
49 int feat_height = feat_blob_->shape(2);
50 int feat_width = feat_blob_->shape(3);
// An input size of 0 in the options means "use the feature-map size".
51
53 options.input_height == 0 ? feat_height : options.input_height;
54 input_width_ = options.input_width == 0 ? feat_width : options.input_width;
55
// Require the same integer downsampling factor in both dimensions.
// NOTE(review):
//  - This uses integer division, so non-integer ratios are truncated
//    before comparison (e.g. a 100x90 input over a 30x30 map passes,
//    3 == 3).
//  - It divides by feat_height/feat_width without guarding against 0.
//  - CHECK_EQ is a fatal assertion rather than a `return false`.
56 CHECK_EQ(input_height_ / feat_height, input_width_ / feat_width)
57 << "Invalid aspect ratio: " << feat_height << "x" << feat_width
58 << " from " << input_height_ << "x" << input_width_;
// Build one layer per supported extractor. Only ROIPooling has a case. Its
// body is not visible here; presumably it calls init_roipooling with the
// entry's ROI pooling parameters, since that is the only visible code that
// appends to roi_poolings_. Other feature types fall through the switch
// silently and are reported by the empty check below.
// NOTE(review): roi_poolings_ is not cleared in the visible code, so
// calling Init more than once keeps appending layers (and a repeated call
// could pass the empty check on stale layers).
59 for (int i = 0; i < feat_param.extractor_size(); i++) {
60 switch (feat_param.extractor(i).feat_type()) {
61 case tracking_feature::ExtractorParam_FeatureType_ROIPooling:
63 break;
64 }
65 }
66 if (roi_poolings_.empty()) {
67 AERROR << "no proper extractor";
68 return false;
69 }
70 return true;
71}
72
// Builds a single ROI pooling stage and appends it to roi_poolings_.
//
// Reads the channel count (axis 1) and spatial size (axes 2 and 3) from
// feat_blob_, so feat_blob_ must already be allocated; Init allocates it
// before its extractor loop. The spatial size is cached in
// feat_height_/feat_width_, which Extract multiplies the normalized box
// coordinates by.
//
// @param param  ROI pooling settings; pooled_w, pooled_h and use_floor are
//               read here.
75 int feat_channel = feat_blob_->shape(1);
76 feat_height_ = feat_blob_->shape(2);
77 feat_width_ = feat_blob_->shape(3);
78
// Each layer owns its own ROI buffer, pooling op and output blob.
79 std::shared_ptr<FeatureExtractorLayer> feature_extractor_layer_ptr;
80 feature_extractor_layer_ptr.reset(new FeatureExtractorLayer());
// The ROI buffer starts as one 5-float row. Extract reshapes it to one row
// per detected object, written as [0, x1, y1, x2, y2]. The leading 0 is
// presumably the batch/image index.
81 std::vector<int> shape{1, 5};
82 feature_extractor_layer_ptr->rois_blob.reset(new base::Blob<float>(shape));
83 int pooled_w = param.pooled_w();
84 int pooled_h = param.pooled_h();
85 bool use_floor = param.use_floor();
// NOTE(review): the literal `1` passed to ROIPoolingLayer is presumably a
// spatial scale of 1, consistent with Extract already scaling boxes into
// feature-map units — confirm against inference::ROIPoolingLayer.
86 feature_extractor_layer_ptr->pooling_layer.reset(
87 new inference::ROIPoolingLayer<float>(pooled_h, pooled_w, use_floor, 1,
88 feat_channel));
// NOTE(review): top_blob is sized with feat_blob_->channels() while the
// pooling layer above gets shape(1); these are presumably the same value.
// top_blob is also not referenced in Extract, which writes the pooled
// output into frame->track_feature_blob instead.
89 feature_extractor_layer_ptr->top_blob.reset(
90 new base::Blob<float>(1, feat_blob_->channels(), pooled_h, pooled_w));
91 roi_poolings_.push_back(feature_extractor_layer_ptr);
92}
93
95 CameraTrackingFrame *frame) {
// Pools one feature vector per detected object from frame->feature_blob
// into frame->track_feature_blob, then applies norm_.L2Norm to the result.
//
// @param options  When options.normalized is false, the detected boxes are
//                 normalized in place (encode_bbox) before pooling and
//                 restored (decode_bbox) afterwards.
// @param frame    Reads detected_objects and feature_blob; writes
//                 track_feature_blob. Its boxes are temporarily modified
//                 when options.normalized is false.
// @return false if frame is null; true otherwise, including when there are
//         no detected objects (nothing is computed in that case).
96 if (frame == nullptr) {
97 return false;
98 }
99 if (frame->detected_objects.empty()) {
100 return true;
101 }
// The ROI rows below multiply box coordinates by the feature-map size, so
// they need normalized boxes.
102 if (!options.normalized) {
103 // normalize bbox
104 encode_bbox(&(frame->detected_objects));
105 }
// NOTE(review): frame->feature_blob is handed to the pooling layer without
// a null check.
106 auto feat_blob = frame->feature_blob;
107 for (auto feature_extractor_layer_ptr : roi_poolings_) {
// One ROI row per detected object.
108 feature_extractor_layer_ptr->rois_blob->Reshape(
109 {static_cast<int>(frame->detected_objects.size()), 5});
110 float *rois_data =
111 feature_extractor_layer_ptr->rois_blob->mutable_cpu_data();
112 for (const auto &obj : frame->detected_objects) {
113 rois_data[0] = 0;
114 rois_data[1] =
115 obj->camera_supplement.box.xmin * static_cast<float>(feat_width_);
116 rois_data[2] =
117 obj->camera_supplement.box.ymin * static_cast<float>(feat_height_);
118 rois_data[3] =
119 obj->camera_supplement.box.xmax * static_cast<float>(feat_width_);
120 rois_data[4] =
121 obj->camera_supplement.box.ymax * static_cast<float>(feat_height_);
122 ADEBUG << rois_data[0] << " " << rois_data[1] << " " << rois_data[2]
123 << " " << rois_data[3] << " " << rois_data[4];
// Advance to the next ROI row (presumably 5 floats for the {N, 5} blob).
124 rois_data += feature_extractor_layer_ptr->rois_blob->offset(1);
125 }
126
127 // obtain the track feature blob
// Inputs are the feature map and the ROIs; the output is
// frame->track_feature_blob.
// NOTE(review): every iteration writes the same output blob, so only the
// last layer's result survives if roi_poolings_ holds more than one layer.
128 feature_extractor_layer_ptr->pooling_layer->ForwardGPU(
129 {feat_blob, feature_extractor_layer_ptr->rois_blob},
130 {frame->track_feature_blob});
131
// NOTE(review): the boxes are restored inside the per-layer loop, but they
// were normalized once before it. This causes two problems:
//  - With more than one layer, later iterations read already-denormalized
//    boxes and decode_bbox runs on them again.
//  - If roi_poolings_ is empty (Init not run successfully), the boxes are
//    left normalized.
// Moving this decode after the loop would avoid both.
132 if (!options.normalized) {
133 // denormalize bbox
134 decode_bbox(&(frame->detected_objects));
135 }
136 }
// Normalize the pooled features in place (norm_.L2Norm; presumably a
// per-object L2 normalization — confirm against its definition).
137 norm_.L2Norm(frame->track_feature_blob.get());
138 return true;
139}
140
142} // namespace camera
143} // namespace perception
144} // namespace apollo
#define REGISTER_FEATURE_EXTRACTOR(name)
A wrapper around SyncedMemory holders serving as the basic computational unit for images,...
Definition blob.h:88
void encode_bbox(std::vector< std::shared_ptr< base::Object > > *objects)
std::shared_ptr< base::Blob< float > > feat_blob_
void decode_bbox(std::vector< std::shared_ptr< base::Object > > *objects)
bool Extract(const FeatureExtractorOptions &options, CameraTrackingFrame *frame) override
extract feature from frame
void init_roipooling(const tracking_feature::ROIPoolingParam &param)
std::vector< std::shared_ptr< FeatureExtractorLayer > > roi_poolings_
bool Init(const FeatureExtractorInitOptions &init_options) override
init feature extractor
void L2Norm(base::Blob< float > *input_data)
#define ADEBUG
Definition log.h:41
#define AERROR
Definition log.h:44
bool GetProtoFromFile(const std::string &file_name, google::protobuf::Message *message)
Parses the content of the file specified by the file_name as a representation of protobufs,...
Definition file.cc:132
std::string GetConfigFile(const std::string &config_path, const std::string &config_file)
Definition util.cc:80
class register implement
Definition arena_queue.h:37
std::shared_ptr< base::Blob< float > > feature_blob
std::shared_ptr< base::Blob< float > > track_feature_blob