Apollo 10.0
自动驾驶开放平台
apollo::audio::SirenDetection 类参考

#include <siren_detection.h>

apollo::audio::SirenDetection 的协作图:

Public 成员函数

 SirenDetection ()
 
 ~SirenDetection ()=default
 
bool Evaluate (const std::vector< std::vector< double > > &signals)
 

详细描述

在文件 siren_detection.h 第 32 行定义.

构造及析构函数说明

◆ SirenDetection()

apollo::audio::SirenDetection::SirenDetection ( )

在文件 siren_detection.cc 第 30 行定义.

30 : device_(torch::kCPU) {
31 LoadModel();
32}

◆ ~SirenDetection()

apollo::audio::SirenDetection::~SirenDetection ( )
[default]（即 = default，由编译器生成的默认析构函数）

成员函数说明

◆ Evaluate()

bool apollo::audio::SirenDetection::Evaluate ( const std::vector< std::vector< double > > &  signals)

在文件 siren_detection.cc 第 34 行定义.

34 {
35 // Sanity checks.
36 omp_set_num_threads(1);
37 if (signals.size() == 0) {
38 AERROR << "Got no channel in signals!";
39 return false;
40 }
41 if (signals[0].size() == 0) {
42 AERROR << "Got no signal in channel 0!";
43 return false;
44 }
45 if (signals[0].size() != 72000) {
46 AERROR << "signals[0].size() = " << signals[0].size() << ", skiping!";
47 return false;
48 }
49 torch::Tensor audio_tensor = torch::empty(4 * 1 * 72000);
50 float* data = audio_tensor.data_ptr<float>();
51
52 for (const auto& channel : signals) {
53 for (const auto& i : channel) {
54 *data++ = static_cast<float>(i) / 32767.0;
55 }
56 }
57
58 torch::Tensor torch_input = torch::from_blob(audio_tensor.data_ptr<float>(),
59 {4, 1, 72000});
60 std::vector<torch::jit::IValue> torch_inputs;
61 torch_inputs.push_back(torch_input.to(device_));
62
63 auto start_time = std::chrono::system_clock::now();
64 at::Tensor torch_output_tensor = torch_model_.forward(torch_inputs).toTensor()
65 .to(torch::kCPU);
66
67 auto end_time = std::chrono::system_clock::now();
68 std::chrono::duration<double> diff = end_time - start_time;
69 AINFO << "SirenDetection used time: " << diff.count() * 1000 << " ms.";
70 auto torch_output = torch_output_tensor.accessor<float, 2>();
71
72 // majority vote with 4 channels
73 float neg_score = torch_output[0][0] + torch_output[1][0] +
74 torch_output[2][0] + torch_output[3][0];
75 float pos_score = torch_output[0][1] + torch_output[1][1] +
76 torch_output[2][1] + torch_output[3][1];
77 ADEBUG << "neg_score = " << neg_score << ", pos_score = " << pos_score;
78 if (neg_score < pos_score) {
79 return true;
80 } else {
81 return false;
82 }
83}
#define ADEBUG
定义于 log.h 第 41 行
#define AERROR
定义于 log.h 第 44 行
#define AINFO
定义于 log.h 第 42 行

该类的文档由以下文件生成: