Apollo 10.0
自动驾驶开放平台
caddn_obstacle_detector.cc
浏览该文件的文档.
1/******************************************************************************
2 * Copyright 2023 The Apollo Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *****************************************************************************/
17
18#include <map>
19#include <memory>
20
21#include "opencv2/opencv.hpp"
22
23#include "cyber/common/file.h"
24#include "cyber/common/log.h"
28
29namespace apollo {
30namespace perception {
31namespace camera {
32
// Lookup table from the class names listed in the model config
// (model_param.info().class_names()) to perception object sub-types.
// InitTypes() rejects the model config if any configured class name is
// missing from this table.
33const std::map<std::string, base::ObjectSubType> kKITTIName2SubTypeMap = {
35 {"Pedestrian", base::ObjectSubType::PEDESTRIAN},
37 {"MAX_OBJECT_TYPE", base::ObjectSubType::MAX_OBJECT_TYPE},
38};
39
40void CaddnObstacleDetector::InitImageSize(
41 const caddn::ModelParam &model_param) {
42 // resize
43 auto resize = model_param.resize();
44 if (resize.width() == 0 && resize.height() == 0) {
45 width_ = options_.image_width * resize.fx();
46 height_ = options_.image_height * resize.fy();
47 } else {
48 width_ = resize.width();
49 height_ = resize.height();
50 }
51
52 AINFO << "height=" << height_ << ", "
53 << "width=" << width_;
54}
55
56void CaddnObstacleDetector::InitParam(const caddn::ModelParam &model_param) {
57 // confidence threshold
58 score_threshold_ = model_param.score_threshold();
59}
60
61bool CaddnObstacleDetector::InitTypes(const caddn::ModelParam &model_param) {
62 for (const auto &class_name : model_param.info().class_names()) {
63 if (kKITTIName2SubTypeMap.find(class_name) != kKITTIName2SubTypeMap.end()) {
64 types_.push_back(kKITTIName2SubTypeMap.at(class_name));
65 } else {
66 AERROR << "Unsupported subtype type!" << class_name;
67 return false;
68 }
69 }
70 return true;
71}
72
74 options_ = options;
75 std::string config_file =
76 GetConfigFile(options_.config_path, options_.config_file);
77 if (!cyber::common::GetProtoFromFile(config_file, &model_param_)) {
78 AERROR << "Read model param failed!";
79 return false;
80 }
81
82 InitImageSize(model_param_);
83 InitParam(model_param_);
84 InitTypes(model_param_);
85
86 const auto &model_info = model_param_.info();
87 std::string model_path = GetModelPath(model_info.name());
88 if (!InitNetwork(model_param_.info(), model_path)) {
89 AERROR << "Init network failed!";
90 return false;
91 }
92 return true;
93}
94
95bool CaddnObstacleDetector::Preprocess(const base::Image8U *image,
96 base::BlobPtr<float> input_blob) {
97 cv::Mat img = cv::Mat(image->rows(), image->cols(), CV_8UC3);
98 memcpy(img.data, image->cpu_data(), image->total() * sizeof(uint8_t));
99
100 // resize
101 cv::resize(img, img, cv::Size(width_, height_));
102
103 // mean and std
104 img.convertTo(img, CV_32F, 1.0 / 255, 0);
105 std::vector<float> mean_values{0, 0, 0};
106 std::vector<float> std_values{0.229, 0.224, 0.225};
107
108 std::vector<cv::Mat> rgbChannels(3);
109 cv::split(img, rgbChannels);
110 for (int i = 0; i < 3; ++i) {
111 rgbChannels[i].convertTo(rgbChannels[i], CV_32FC1, 1 / std_values[i],
112 (0.0 - mean_values[i]) / std_values[i]);
113 }
114 cv::merge(rgbChannels, img);
115
116 // from hwc to chw
117 int rows = img.rows;
118 int cols = img.cols;
119 int chs = img.channels();
120
121 // fill input_blob
122 input_blob->Reshape({1, chs, rows, cols});
123 float *input_data = input_blob->mutable_cpu_data();
124 for (int i = 0; i < chs; ++i) {
125 cv::extractChannel(
126 img, cv::Mat(rows, cols, CV_32FC1, input_data + i * rows * cols), i);
127 }
128 return true;
129}
130
132 if (frame == nullptr) {
133 return false;
134 }
135
136 // Inputs
137 auto model_inputs = model_param_.info().inputs();
138 auto input_image_blob = net_->get_blob(model_inputs[0].name());
139 auto input_cam2img_blob = net_->get_blob(model_inputs[1].name());
140 auto input_lidar2cam_blob = net_->get_blob(model_inputs[2].name());
141
142 const auto &camera_k_matrix = options_.intrinsic;
143 float *input_cam_intrinsic = input_cam2img_blob->mutable_cpu_data();
144 for (size_t i = 0; i < 3; ++i) {
145 for (size_t j = 0; j < 4; ++j) {
146 if (3 == j) {
147 input_cam_intrinsic[i * 4 + j] = 0.0;
148 } else {
149 input_cam_intrinsic[i * 4 + j] = camera_k_matrix(i, j);
150 }
151 }
152 }
153
154 float *input_lidar2cam_data = input_lidar2cam_blob->mutable_cpu_data();
155 for (int i = 0; i < 16; ++i) {
156 input_lidar2cam_data[i] = lidar_to_cam_[i];
157 }
158
159 DataProvider::ImageOptions image_options;
160 image_options.target_color = base::Color::BGR;
161 std::shared_ptr<base::Image8U> image = std::make_shared<base::Image8U>();
162 frame->data_provider->GetImage(image_options, image.get());
163
164 Preprocess(image.get(), input_image_blob);
165
166 // Infer
167 net_->Infer();
168
169 // Outputs
170 auto model_outputs = model_param_.info().outputs();
171 auto out_detections = net_->get_blob(model_outputs[0].name());
172 auto out_labels = net_->get_blob(model_outputs[1].name());
173 auto out_scores = net_->get_blob(model_outputs[2].name());
174
175 // todo(daohu527): The caddn model currently does not output tracking features
176 // appearance features for tracking
177 // frame->feature_blob = net_->get_blob(model_outputs[3].name());
178
179 GetCaddnObjects(&frame->detected_objects, model_param_, types_,
180 out_detections, out_labels, out_scores);
181
182 return true;
183}
184
186
187} // namespace camera
188} // namespace perception
189} // namespace apollo
#define REGISTER_OBSTACLE_DETECTOR(name)
A wrapper around Blob holders serving as the basic computational unit for images.
Definition image_8u.h:44
const uint8_t * cpu_data() const
Definition image_8u.h:99
virtual bool InitNetwork(const common::ModelInfo &model_info, const std::string &model_root)
Interface for network initialization
std::shared_ptr< inference::Inference > net_
bool Detect(onboard::CameraFrame *frame) override
Main part to detect obstacle
bool Init(const ObstacleDetectorInitOptions &options=ObstacleDetectorInitOptions()) override
Init function for CaddnObstacleDetector
#define AERROR
Definition log.h:44
#define AINFO
Definition log.h:42
bool GetProtoFromFile(const std::string &file_name, google::protobuf::Message *message)
Parses the content of the file specified by the file_name as a representation of protobufs,...
Definition file.cc:132
std::shared_ptr< Blob< Dtype > > BlobPtr
Definition blob.h:313
const std::map< std::string, base::ObjectSubType > kKITTIName2SubTypeMap
void GetCaddnObjects(std::vector< base::ObjectPtr > *objects, const caddn::ModelParam &model_param, const std::vector< base::ObjectSubType > &types, const base::BlobPtr< float > &boxes, const base::BlobPtr< float > &labels, const base::BlobPtr< float > &scores)
Get the Caddn Objects objects from Blob
std::string GetConfigFile(const std::string &config_path, const std::string &config_file)
Definition util.cc:80
std::string GetModelPath(const std::string &model_name)
Get the model path by model name, search from APOLLO_MODEL_PATH
Definition util.cc:44
class register implement
Definition arena_queue.h:37
std::shared_ptr< camera::DataProvider > data_provider