Apollo 10.0
自动驾驶开放平台
nms.h
浏览该文件的文档.
1/******************************************************************************
2 * Copyright 2018 The Apollo Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *****************************************************************************/
16#pragma once
17
18#include <vector>
19#include <algorithm>
20#include <limits>
21#include <utility>
22
23#include "cyber/common/log.h"
24
25namespace apollo {
26namespace perception {
27namespace base {
28
29
30void GetMaxScoreIndex(const std::vector<float> &scores, const float threshold,
31 const int top_k,
32 std::vector<std::pair<float, int>> *score_index_vec) {
33 ACHECK(score_index_vec != nullptr);
34
35 for (size_t i = 0; i < scores.size(); ++i) {
36 if (scores[i] > threshold) {
37 score_index_vec->emplace_back(scores[i], i);
38 }
39 }
40
41 std::sort(score_index_vec->begin(), score_index_vec->end(),
42 [](const std::pair<float, int> &a, const std::pair<float, int> &b) {
43 return a.first > b.first;
44 });
45
46 if (top_k > 0 && top_k < static_cast<int>(score_index_vec->size())) {
47 score_index_vec->resize(top_k);
48 }
49}
50
51template <typename BoxType>
52void Nms(const std::vector<BoxType> &bboxes, const std::vector<float> &scores,
53 const float score_threshold, const float nms_threshold,
54 const float eta, const int top_k, std::vector<int> *indices,
55 float (*ComputeOverlap)(const BoxType &, const BoxType &),
56 int limit = std::numeric_limits<int>::max) {
57 ACHECK(bboxes.size() == scores.size());
58 ACHECK(indices != nullptr);
59
60 std::vector<std::pair<float, int>> score_index_vec;
61 GetMaxScoreIndex(scores, score_threshold, top_k, &score_index_vec);
62
63 float adaptive_threshold = nms_threshold;
64 indices->clear();
65 for (const std::pair<float, int> &score_index : score_index_vec) {
66 const int idx = score_index.second;
67 bool keep = true;
68 for (const int &kept_idx : *indices) {
69 float overlap = ComputeOverlap(bboxes[idx], bboxes[kept_idx]);
70 if (overlap > adaptive_threshold) {
71 keep = false;
72 break;
73 }
74 }
75
76 if (keep) {
77 indices->push_back(idx);
78 if (static_cast<int>(indices->size()) >= limit) {
79 break;
80 }
81 }
82
83 if (keep && eta < 1 && adaptive_threshold > 0.5) {
84 adaptive_threshold *= eta;
85 }
86 }
87}
88
89} // namespace base
90} // namespace perception
91} // namespace apollo
#define ACHECK(cond)
Definition log.h:80
void Nms(const std::vector< BoxType > &bboxes, const std::vector< float > &scores, const float score_threshold, const float nms_threshold, const float eta, const int top_k, std::vector< int > *indices, float(*ComputeOverlap)(const BoxType &, const BoxType &), int limit=std::numeric_limits< int >::max)
Definition nms.h:52
void GetMaxScoreIndex(const std::vector< float > &scores, const float threshold, const int top_k, std::vector< std::pair< float, int > > *score_index_vec)
Definition nms.h:30
class register implement
Definition arena_queue.h:37