Apollo 10.0
自动驾驶开放平台
project_feature.cc
浏览该文件的文档.
1/******************************************************************************
2 * Copyright 2018 The Apollo Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *****************************************************************************/
17
18#include <map>
19#include <vector>
20
21#include "cyber/common/file.h"
22#include "cyber/common/log.h"
28
29namespace apollo {
30namespace perception {
31namespace camera {
32
34 std::string config_file =
35 GetConfigFile(options.config_path, options.config_file);
36 if (!cyber::common::GetProtoFromFile(config_file, &model_param_)) {
37 AERROR << "Read config failed: " << config_file;
38 return false;
39 }
40
41 AINFO << "Load config Success: " << model_param_.ShortDebugString();
42
43 const auto &model_info = model_param_.info();
44 std::string model_path = GetModelPath(model_info.name());
45 std::string proto_file =
46 GetModelFile(model_path, model_info.proto_file().file());
47 std::string weight_file =
48 GetModelFile(model_path, model_info.weight_file().file());
49
50 // Network input and output names
51 std::vector<std::string> input_names =
52 inference::GetBlobNames(model_info.inputs());
53 std::vector<std::string> output_names =
54 inference::GetBlobNames(model_info.outputs());
55
57 model_info.framework(), proto_file, weight_file, output_names,
58 input_names, model_path));
59 ACHECK(nullptr != net_) << "Failed to init CNNAdapter";
60
61 gpu_id_ = GlobalConfig::Instance()->track_feature_gpu_id;
62 net_->set_gpu_id(gpu_id_);
63 net_->set_max_batch_size(100);
64
65 std::map<std::string, std::vector<int>> shape_map;
66 inference::AddShape(&shape_map, model_info.inputs());
67 inference::AddShape(&shape_map, model_info.outputs());
68
69 if (!net_->Init(shape_map)) {
70 AERROR << model_info.name() << "init failed!";
71 return false;
72 }
73
74 net_->Infer();
75 return true;
76}
77
79 CameraTrackingFrame *frame) {
80 auto input_blob = net_->get_blob(model_param_.info().inputs(0).name());
81 auto output_blob = net_->get_blob(model_param_.info().outputs(0).name());
82 if (frame->detected_objects.empty()) {
83 return true;
84 }
85 input_blob->Reshape(frame->track_feature_blob->shape());
86 cudaMemcpy(
87 input_blob->mutable_gpu_data(), frame->track_feature_blob->gpu_data(),
88 frame->track_feature_blob->count() * sizeof(float), cudaMemcpyDefault);
89
90 cudaDeviceSynchronize();
91 net_->Infer();
92 cudaDeviceSynchronize();
93 frame->track_feature_blob->Reshape(
94 {static_cast<int>(frame->detected_objects.size()), output_blob->shape(1),
95 output_blob->shape(2), output_blob->shape(3)});
96
97 cudaMemcpy(
98 frame->track_feature_blob->mutable_gpu_data(), output_blob->gpu_data(),
99 frame->track_feature_blob->count() * sizeof(float), cudaMemcpyDefault);
100
101 norm_.L2Norm(frame->track_feature_blob.get());
102 return true;
103}
104
106
107} // namespace camera
108} // namespace perception
109} // namespace apollo
#define REGISTER_FEATURE_EXTRACTOR(name)
bool Init(const FeatureExtractorInitOptions &init_options) override
bool Extract(const FeatureExtractorOptions &options, CameraTrackingFrame *frame) override
void L2Norm(base::Blob< float > *input_data)
#define ACHECK(cond)
Definition log.h:80
#define AERROR
Definition log.h:44
#define AINFO
Definition log.h:42
bool GetProtoFromFile(const std::string &file_name, google::protobuf::Message *message)
Parses the content of the file specified by the file_name as a representation of protobufs,...
Definition file.cc:132
std::vector< std::string > GetBlobNames(const google::protobuf::RepeatedPtrField< common::ModelBlob > &model_blobs)
Definition model_util.cc:23
void AddShape(std::map< std::string, std::vector< int > > *shape_map, const google::protobuf::RepeatedPtrField< common::ModelBlob > &model_blobs)
Definition model_util.cc:32
Inference * CreateInferenceByName(const std::string &frame_work, const std::string &proto_file, const std::string &weight_file, const std::vector< std::string > &outputs, const std::vector< std::string > &inputs, const std::string &model_root)
std::string GetModelFile(const std::string &model_name, const std::string &file_name)
Get the model file path by model path and file name
Definition util.cc:55
std::string GetConfigFile(const std::string &config_path, const std::string &config_file)
Definition util.cc:80
std::string GetModelPath(const std::string &model_name)
Get the model path by model name, search from APOLLO_MODEL_PATH
Definition util.cc:44
class register implement
Definition arena_queue.h:37
std::shared_ptr< base::Blob< float > > track_feature_blob