Apollo 11.0
自动驾驶开放平台
gated_hungarian_bigraph_matcher.h
浏览该文件的文档.
1/******************************************************************************
2 * Copyright 2018 The Apollo Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *****************************************************************************/
16
17#pragma once
18
19#include <algorithm>
20#include <functional>
21#include <map>
22#include <utility>
23#include <vector>
24
25#include "cyber/common/log.h"
26
29
30namespace apollo {
31namespace perception {
32namespace algorithm {
33
34template <typename T>
36 public:
37 enum class OptimizeFlag { OPTMAX, OPTMIN };
38
39 explicit GatedHungarianMatcher(int max_matching_size = 1000) {
40 global_costs_.Reserve(max_matching_size, max_matching_size);
41 optimizer_.costs()->Reserve(max_matching_size, max_matching_size);
42 }
44
45 /* @brief: global_costs is the memory we reserved for the updating of
46 * costs of matching. it could & need be updated outside the matcher,
47 * before each matching. use it carefully, and make sure all the
48 * elements of the global_costs is updated as you presumed. resize it
49 * and update it completely is STRONG RECOMMENDED!! P.S. resizing SecureMat
50 * would not alloc new memory, if the resizing size is smaller than the
51 * size reserved. */
52 const SecureMat<T>& global_costs() const { return global_costs_; }
53 SecureMat<T>* mutable_global_costs() { return &global_costs_; }
54
55 void Match(T cost_thresh, OptimizeFlag opt_flag,
56 std::vector<std::pair<size_t, size_t>>* assignments,
57 std::vector<size_t>* unassigned_rows,
58 std::vector<size_t>* unassigned_cols);
59
60 void Match(T cost_thresh, T bound_value, OptimizeFlag opt_flag,
61 std::vector<std::pair<size_t, size_t>>* assignments,
62 std::vector<size_t>* unassigned_rows,
63 std::vector<size_t>* unassigned_cols);
64
65 private:
66 /* Step 1:
67 * a. get number of rows & cols
68 * b. determine function of comparison */
69 void MatchInit();
70
71 /* Step 2:
72 * to acclerate matching process, split input cost graph into several
73 * small sub-parts. */
74 void ComputeConnectedComponents(
75 std::vector<std::vector<size_t>>* row_components,
76 std::vector<std::vector<size_t>>* col_components) const;
77
78 /* Step 3:
79 * optimize single connected component, which is part of the global one */
80 void OptimizeConnectedComponent(const std::vector<size_t>& row_component,
81 const std::vector<size_t>& col_component);
82
83 /* Step 4:
84 * generate the set of unassigned row or col index. */
85 void GenerateUnassignedData(std::vector<size_t>* unassigned_rows,
86 std::vector<size_t>* unassigned_cols) const;
87
88 /* @brief: core function for updating the local cost matrix from global one,
89 * we get queryed local costs and write them in the memeory of costs of
90 * optimizer directly
91 * @params[IN] row_component: the set of index of rows of sub-graph
92 * @params[IN] col_component: the set of index of cols of sub-graph
93 * @return: nothing */
94 void UpdateGatingLocalCostsMat(const std::vector<size_t>& row_component,
95 const std::vector<size_t>& col_component);
96
97 void OptimizeAdapter(
98 std::vector<std::pair<size_t, size_t>>* local_assignments);
99
100 /* Hungarian optimizer */
101 HungarianOptimizer<T> optimizer_;
102
103 /* global costs matrix */
104 SecureMat<T> global_costs_;
105
106 /* input data */
107 T cost_thresh_ = 0.0;
108 T bound_value_ = 0.0;
110
111 /* output data */
112 mutable std::vector<std::pair<size_t, size_t>>* assignments_ptr_ = nullptr;
113
114 /* size of component */
115 size_t rows_num_ = 0;
116 size_t cols_num_ = 0;
117
118 /* the rhs is always better than lhs */
119 std::function<bool(T, T)> compare_fun_;
120 std::function<bool(T)> is_valid_cost_;
121}; // class GatedHungarianMatcher
122
123template <typename T>
125 T cost_thresh, OptimizeFlag opt_flag,
126 std::vector<std::pair<size_t, size_t>>* assignments,
127 std::vector<size_t>* unassigned_rows,
128 std::vector<size_t>* unassigned_cols) {
129 Match(cost_thresh, cost_thresh, opt_flag, assignments, unassigned_rows,
130 unassigned_cols);
131}
132
133template <typename T>
135 T cost_thresh, T bound_value, OptimizeFlag opt_flag,
136 std::vector<std::pair<size_t, size_t>>* assignments,
137 std::vector<size_t>* unassigned_rows,
138 std::vector<size_t>* unassigned_cols) {
139 CHECK_NOTNULL(assignments);
140 CHECK_NOTNULL(unassigned_rows);
141 CHECK_NOTNULL(unassigned_cols);
142
143 /* initialize matcher */
144 cost_thresh_ = cost_thresh;
145 opt_flag_ = opt_flag;
146 bound_value_ = bound_value;
147 assignments_ptr_ = assignments;
148 MatchInit();
149
150 /* compute components */
151 std::vector<std::vector<size_t>> row_components;
152 std::vector<std::vector<size_t>> col_components;
153 this->ComputeConnectedComponents(&row_components, &col_components);
154 CHECK_EQ(row_components.size(), col_components.size());
155
156 /* compute assignments */
157 assignments_ptr_->clear();
158 assignments_ptr_->reserve(std::max(rows_num_, cols_num_));
159 for (size_t i = 0; i < row_components.size(); ++i) {
160 this->OptimizeConnectedComponent(row_components[i], col_components[i]);
161 }
162
163 this->GenerateUnassignedData(unassigned_rows, unassigned_cols);
164}
165
166template <typename T>
168 /* get number of rows & cols */
169 rows_num_ = global_costs_.height();
170 cols_num_ = (rows_num_ == 0) ? 0 : global_costs_.width();
171
172 /* determine function of comparison */
173 static std::map<OptimizeFlag, std::function<bool(T, T)>> compare_fun_map = {
174 {OptimizeFlag::OPTMAX, std::less<T>()},
175 {OptimizeFlag::OPTMIN, std::greater<T>()},
176 };
177 auto find_ret = compare_fun_map.find(opt_flag_);
178 ACHECK(find_ret != compare_fun_map.end());
179 compare_fun_ = find_ret->second;
180 is_valid_cost_ = std::bind1st(compare_fun_, cost_thresh_);
181
182 /* check the validity of bound_value */
183 ACHECK(!is_valid_cost_(bound_value_));
184}
185
186template <typename T>
187void GatedHungarianMatcher<T>::ComputeConnectedComponents(
188 std::vector<std::vector<size_t>>* row_components,
189 std::vector<std::vector<size_t>>* col_components) const {
190 CHECK_NOTNULL(row_components);
191 CHECK_NOTNULL(col_components);
192
193 std::vector<std::vector<int>> nb_graph;
194 nb_graph.resize(rows_num_ + cols_num_);
195 for (size_t i = 0; i < rows_num_; ++i) {
196 for (size_t j = 0; j < cols_num_; ++j) {
197 if (is_valid_cost_(global_costs_(i, j))) {
198 nb_graph[i].push_back(static_cast<int>(rows_num_) + j);
199 nb_graph[j + rows_num_].push_back(i);
200 }
201 }
202 }
203
204 std::vector<std::vector<int>> components;
205 ConnectedComponentAnalysis(nb_graph, &components);
206 row_components->clear();
207 row_components->resize(components.size());
208 col_components->clear();
209 col_components->resize(components.size());
210 for (size_t i = 0; i < components.size(); ++i) {
211 for (size_t j = 0; j < components[i].size(); ++j) {
212 int id = components[i][j];
213 if (id < static_cast<int>(rows_num_)) {
214 row_components->at(i).push_back(id);
215 } else {
216 id -= static_cast<int>(rows_num_);
217 col_components->at(i).push_back(id);
218 }
219 }
220 }
221}
222
223template <typename T>
224void GatedHungarianMatcher<T>::OptimizeConnectedComponent(
225 const std::vector<size_t>& row_component,
226 const std::vector<size_t>& col_component) {
227 size_t local_rows_num = row_component.size();
228 size_t local_cols_num = col_component.size();
229
230 /* simple case 1: no possible matches */
231 if (!local_rows_num || !local_cols_num) {
232 return;
233 }
234 /* simple case 2: 1v1 pair with no ambiguousness */
235 if (local_rows_num == 1 && local_cols_num == 1) {
236 size_t idx_r = row_component[0];
237 size_t idx_c = col_component[0];
238 if (is_valid_cost_(global_costs_(idx_r, idx_c))) {
239 assignments_ptr_->push_back(std::make_pair(idx_r, idx_c));
240 }
241 return;
242 }
243
244 /* update local cost matrix */
245 UpdateGatingLocalCostsMat(row_component, col_component);
246
247 /* get local assignments */
248 std::vector<std::pair<size_t, size_t>> local_assignments;
249 OptimizeAdapter(&local_assignments);
250
251 /* parse local assginments into global ones */
252 for (size_t i = 0; i < local_assignments.size(); ++i) {
253 auto local_assignment = local_assignments[i];
254 size_t global_row_idx = row_component[local_assignment.first];
255 size_t global_col_idx = col_component[local_assignment.second];
256 if (!is_valid_cost_(global_costs_(global_row_idx, global_col_idx))) {
257 continue;
258 }
259 assignments_ptr_->push_back(std::make_pair(global_row_idx, global_col_idx));
260 }
261}
262
263template <typename T>
264void GatedHungarianMatcher<T>::GenerateUnassignedData(
265 std::vector<size_t>* unassigned_rows,
266 std::vector<size_t>* unassigned_cols) const {
267 CHECK_NOTNULL(unassigned_rows);
268 CHECK_NOTNULL(unassigned_cols);
269
270 const auto assignments = *assignments_ptr_;
271 unassigned_rows->clear(), unassigned_rows->reserve(rows_num_);
272 unassigned_cols->clear(), unassigned_cols->reserve(cols_num_);
273 std::vector<bool> row_assignment_flags(rows_num_, false);
274 std::vector<bool> col_assignment_flags(cols_num_, false);
275 for (const auto& assignment : assignments) {
276 row_assignment_flags[assignment.first] = true;
277 col_assignment_flags[assignment.second] = true;
278 }
279 for (size_t i = 0; i < row_assignment_flags.size(); ++i) {
280 if (!row_assignment_flags[i]) {
281 unassigned_rows->push_back(i);
282 }
283 }
284 for (size_t i = 0; i < col_assignment_flags.size(); ++i) {
285 if (!col_assignment_flags[i]) {
286 unassigned_cols->push_back(i);
287 }
288 }
289}
290
291template <typename T>
292void GatedHungarianMatcher<T>::UpdateGatingLocalCostsMat(
293 const std::vector<size_t>& row_component,
294 const std::vector<size_t>& col_component) {
295 /* set the invalid cost to bound value */
296 SecureMat<T>* local_costs = optimizer_.costs();
297 local_costs->Resize(row_component.size(), col_component.size());
298 for (size_t i = 0; i < row_component.size(); ++i) {
299 for (size_t j = 0; j < col_component.size(); ++j) {
300 T& current_cost = global_costs_(row_component[i], col_component[j]);
301 if (is_valid_cost_(current_cost)) {
302 (*local_costs)(i, j) = current_cost;
303 } else {
304 (*local_costs)(i, j) = bound_value_;
305 }
306 }
307 }
308}
309
310template <typename T>
311void GatedHungarianMatcher<T>::OptimizeAdapter(
312 std::vector<std::pair<size_t, size_t>>* local_assignments) {
313 CHECK_NOTNULL(local_assignments);
314 if (opt_flag_ == OptimizeFlag::OPTMAX) {
315 optimizer_.Maximize(local_assignments);
316 } else {
317 optimizer_.Minimize(local_assignments);
318 }
319}
320
321} // namespace algorithm
322} // namespace perception
323} // namespace apollo
void Match(T cost_thresh, T bound_value, OptimizeFlag opt_flag, std::vector< std::pair< size_t, size_t > > *assignments, std::vector< size_t > *unassigned_rows, std::vector< size_t > *unassigned_cols)
void Match(T cost_thresh, OptimizeFlag opt_flag, std::vector< std::pair< size_t, size_t > > *assignments, std::vector< size_t > *unassigned_rows, std::vector< size_t > *unassigned_cols)
#define ACHECK(cond)
Definition log.h:80
void ConnectedComponentAnalysis(const std::vector< std::vector< int > > &graph, std::vector< std::vector< int > > *components)
class register implement
Definition arena_queue.h:37