Apollo 10.0
自动驾驶开放平台
fused_classifier.cc
浏览该文件的文档.
1/******************************************************************************
2 * Copyright 2023 The Apollo Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *****************************************************************************/
16
18
19#include <vector>
20
21#include "cyber/common/file.h"
24
25namespace apollo {
26namespace perception {
27namespace radar4d {
28
29using ObjectPtr = std::shared_ptr<apollo::perception::base::Object>;
31
// NOTE(review): this extracted listing omits the function's signature line
// (listing line 32) and the declaration of the `config` proto (listing line
// 35); the lines below are the body of the fused classifier's
// Init(const ClassifierInitOptions& options).
//
// Loads the classifier config named by options.config_path /
// options.config_file and caches its fusion settings. It then looks up the
// one-shot and the sequence type fusers through their registerers, using the
// method names read from the config, and initializes each one.
// Returns init_success, which is never set to false in this body. Failures
// are reported through CHECK_NOTNULL / ACHECK instead (presumably fatal
// checks, per log.h / glog -- confirm).
33 std::string config_file =
34 GetConfigFile(options.config_path, options.config_file);
// Parse the resolved config file into `config`.
36 ACHECK(cyber::common::GetProtoFromFile(config_file, &config));
// Cache the fusion settings read from the config.
37 temporal_window_ = config.temporal_window();
38 enable_temporal_fusion_ = config.enable_temporal_fusion();
39 use_tracked_objects_ = config.use_tracked_objects();
40 one_shot_fusion_method_ = config.one_shot_fusion_method();
41 sequence_fusion_method_ = config.sequence_fusion_method();
// Resolve the one-shot fuser implementation by its configured name.
42 one_shot_fuser_ = BaseOneShotTypeFusionRegisterer::GetInstanceByName(
43 one_shot_fusion_method_);
44
45 bool init_success = true;
// Both fusers are initialized with init_option_, whose config_path is taken
// from the caller's options.
46 init_option_.config_path = options.config_path;
47 CHECK_NOTNULL(one_shot_fuser_);
48 ACHECK(one_shot_fuser_->Init(init_option_));
// Resolve and initialize the sequence (temporal) fuser the same way.
49 sequence_fuser_ = BaseSequenceTypeFusionRegisterer::GetInstanceByName(
50 sequence_fusion_method_);
51 CHECK_NOTNULL(sequence_fuser_);
52 ACHECK(sequence_fuser_->Init(init_option_));
53 return init_success;
54}
55
57 RadarFrame* frame) {
// NOTE(review): this extracted listing omits the first signature line
// (listing line 56); this is the body of the fused classifier's
// Classify(const ClassifierOptions& options, RadarFrame* frame).
//
// Classifies the frame's objects.
// - Background objects (radar4d_supplement.is_background) are given a
//   one-hot type_probs on UNKNOWN_UNMOVABLE and type UNKNOWN_UNMOVABLE.
// - Every other object is handed to a type fuser. With temporal fusion, the
//   sequence fuser receives the object's track history; otherwise the
//   one-shot fuser receives the object itself. Presumably the fusers write
//   type / type_probs -- confirm in their implementations.
// Returns false only when frame is null. Otherwise it returns true, even if
// individual fusions failed. `options` is not referenced in this body.
58 if (frame == nullptr) {
59 return false;
60 }
// Work on tracked or segmented objects, as selected by use_tracked_objects_
// (read from the config in Init).
61 std::vector<ObjectPtr>* objects = use_tracked_objects_
62 ? &(frame->tracked_objects)
63 : &(frame->segmented_objects);
// Temporal fusion runs only when it is enabled and the frame has a positive
// timestamp. Otherwise each object gets one-shot fusion.
64 if (enable_temporal_fusion_ && frame->timestamp > 0.0) {
65 // sequence fusion
66 AINFO << "Combined classifier, temporal fusion";
// Record this frame's objects in the sequence history under its timestamp.
67 sequence_.AddTrackedFrameObjects(*objects, frame->timestamp);
// Output buffer reused for every object in the loop below.
// NOTE(review): assumes GetTrackInTemporalWindow clears/overwrites it on
// each call -- confirm in ObjectSequence.
68 ObjectSequence::TrackedObjects tracked_objects;
69 for (auto& object : *objects) {
// Background objects bypass fusion: one-hot on UNKNOWN_UNMOVABLE.
70 if (object->radar4d_supplement.is_background) {
71 object->type_probs.assign(static_cast<int>(ObjectType::MAX_OBJECT_TYPE),
72 0);
73 object->type = ObjectType::UNKNOWN_UNMOVABLE;
74 object->type_probs[static_cast<int>(ObjectType::UNKNOWN_UNMOVABLE)] =
75 1.0;
76 continue;
77 }
// Collect this track's objects within temporal_window_. TrackedObjects is a
// std::map keyed by timestamp, so it is ordered oldest -> newest.
78 const int track_id = object->track_id;
79 sequence_.GetTrackInTemporalWindow(track_id, &tracked_objects,
80 temporal_window_);
81 if (tracked_objects.empty()) {
82 AERROR << "Find zero-length track, so skip.";
83 continue;
84 }
// The current object must be the newest entry of its track (rbegin = largest
// timestamp). If it is not, the history is out of order and the object is
// skipped.
85 if (object != tracked_objects.rbegin()->second) {
86 AERROR << "There must exist some timestamp in disorder, so skip.";
87 continue;
88 }
// NOTE(review): a fusion failure stops processing of the remaining objects
// in this frame (break). The one-shot branch below logs and continues
// instead.
89 if (!sequence_fuser_->TypeFusion(option_, &tracked_objects)) {
90 AERROR << "Failed to fuse types, so break.";
91 break;
92 }
93 }
94 } else {
95 // one shot fusion
96 AINFO << "Combined classifier, one shot fusion";
97 for (auto& object : *objects) {
// Background objects bypass fusion: one-hot on UNKNOWN_UNMOVABLE.
98 if (object->radar4d_supplement.is_background) {
99 object->type_probs.assign(static_cast<int>(ObjectType::MAX_OBJECT_TYPE),
100 0);
101 object->type = ObjectType::UNKNOWN_UNMOVABLE;
102 object->type_probs[static_cast<int>(ObjectType::UNKNOWN_UNMOVABLE)] =
103 1.0;
104 continue;
105 }
// Per-object fusion. A failure is logged and the loop moves on to the next
// object.
106 if (!one_shot_fuser_->TypeFusion(option_, object)) {
107 AERROR << "Failed to fuse types, so continue.";
108 }
109 }
110 }
// Non-null frame: always reports success, regardless of per-object fusion
// failures above.
111 return true;
112}
113
115
116} // namespace radar4d
117} // namespace perception
118} // namespace apollo
virtual bool Init(const TypeFusionInitOption &option)=0
Init type fusion
virtual bool TypeFusion(const TypeFusionOption &option, std::shared_ptr< perception::base::Object > object)=0
Type fusion
virtual bool TypeFusion(const TypeFusionOption &option, TrackedObjects *tracked_objects)=0
Type fusion
virtual bool Init(const TypeFusionInitOption &option)=0
Init type fusion
bool Classify(const ClassifierOptions &options, RadarFrame *frame) override
Classify objects and update type info
bool Init(const ClassifierInitOptions &options=ClassifierInitOptions()) override
Init fused classifier
bool AddTrackedFrameObjects(const std::vector< std::shared_ptr< perception::base::Object > > &objects, TimeStampKey timestamp)
std::map< TimeStampKey, std::shared_ptr< apollo::perception::base::Object > > TrackedObjects
bool GetTrackInTemporalWindow(TrackIdKey track_id, TrackedObjects *track, TimeStampKey window_time)
#define PERCEPTION_REGISTER_CLASSIFIER(name)
#define ACHECK(cond)
Definition log.h:80
#define AERROR
Definition log.h:44
#define AINFO
Definition log.h:42
bool GetProtoFromFile(const std::string &file_name, google::protobuf::Message *message)
Parses the content of the file specified by the file_name as a representation of protobufs,...
Definition file.cc:132
std::shared_ptr< apollo::perception::base::Object > ObjectPtr
std::string GetConfigFile(const std::string &config_path, const std::string &config_file)
Definition util.cc:80
class register implement
Definition arena_queue.h:37
std::vector< std::shared_ptr< base::Object > > tracked_objects
Definition radar_frame.h:49
std::vector< std::shared_ptr< base::Object > > segmented_objects
Definition radar_frame.h:47