// Apollo 11.0
// 自动驾驶开放平台 (Autonomous Driving Open Platform)
// siren_detection.cc
// 浏览该文件的文档. (Go to the documentation of this file.)
/******************************************************************************
 * Copyright 2020 The Apollo Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *****************************************************************************/

// NOTE(review): nothing included here declares SirenDetection or
// FLAGS_torch_siren_detection_model. The project headers that do (presumably
// this file's own header and the audio gflags header) appear to have been
// lost when this listing was extracted; restore them.

#include <chrono>
#include <cstdint>
#include <utility>
#include <vector>

#include <omp.h>

#include "cyber/common/log.h"

27namespace apollo {
28namespace audio {
29
30SirenDetection::SirenDetection() : device_(torch::kCPU) {
31 LoadModel();
32}
33
34bool SirenDetection::Evaluate(const std::vector<std::vector<double>>& signals) {
35 // Sanity checks.
36 omp_set_num_threads(1);
37 if (signals.size() == 0) {
38 AERROR << "Got no channel in signals!";
39 return false;
40 }
41 if (signals[0].size() == 0) {
42 AERROR << "Got no signal in channel 0!";
43 return false;
44 }
45 if (signals[0].size() != 72000) {
46 AERROR << "signals[0].size() = " << signals[0].size() << ", skiping!";
47 return false;
48 }
49 torch::Tensor audio_tensor = torch::empty(4 * 1 * 72000);
50 float* data = audio_tensor.data_ptr<float>();
51
52 for (const auto& channel : signals) {
53 for (const auto& i : channel) {
54 *data++ = static_cast<float>(i) / 32767.0;
55 }
56 }
57
58 torch::Tensor torch_input = torch::from_blob(audio_tensor.data_ptr<float>(),
59 {4, 1, 72000});
60 std::vector<torch::jit::IValue> torch_inputs;
61 torch_inputs.push_back(torch_input.to(device_));
62
63 auto start_time = std::chrono::system_clock::now();
64 at::Tensor torch_output_tensor = torch_model_.forward(torch_inputs).toTensor()
65 .to(torch::kCPU);
66
67 auto end_time = std::chrono::system_clock::now();
68 std::chrono::duration<double> diff = end_time - start_time;
69 AINFO << "SirenDetection used time: " << diff.count() * 1000 << " ms.";
70 auto torch_output = torch_output_tensor.accessor<float, 2>();
71
72 // majority vote with 4 channels
73 float neg_score = torch_output[0][0] + torch_output[1][0] +
74 torch_output[2][0] + torch_output[3][0];
75 float pos_score = torch_output[0][1] + torch_output[1][1] +
76 torch_output[2][1] + torch_output[3][1];
77 ADEBUG << "neg_score = " << neg_score << ", pos_score = " << pos_score;
78 if (neg_score < pos_score) {
79 return true;
80 } else {
81 return false;
82 }
83}
84
85void SirenDetection::LoadModel() {
86 if (torch::cuda::is_available()) {
87 AINFO << "CUDA is available";
88 device_ = torch::Device(torch::kCUDA);
89 }
90 torch::set_num_threads(1);
91 torch_model_ = torch::jit::load(FLAGS_torch_siren_detection_model, device_);
92}
93
94} // namespace audio
95} // namespace apollo
// Doxygen cross-reference residue from the extracted page (not source code):
//   bool Evaluate(const std::vector< std::vector< double > > &signals)
//   #define ADEBUG  -- Definition log.h:41
//   #define AERROR  -- Definition log.h:44
//   #define AINFO   -- Definition log.h:42
//   class register implement  -- Definition arena_queue.h:37