19#include "yaml-cpp/yaml.h"
25#if __has_include("torch/version.h")
26#include "torch/version.h"
33using torch::indexing::None;
34using torch::indexing::Slice;
// Parameter list and body of a function whose opening signature line is not
// visible in this excerpt. NOTE(review): the reference text at the end of the
// file lists EstimateSoundSource(...) returning std::pair<Point3D, double>
// with these same parameters -- presumably this is that function; confirm.
//
// Visible parameters:
//   channels_vec             - per-channel sample buffers, taken by rvalue
//                              reference and forwarded to EstimateDirection.
//   respeaker_extrinsic_file - path handed to LoadExtrinsics on first use.
//   sample_rate              - forwarded to EstimateDirection.
//   mic_distance             - forwarded to EstimateDirection.
// Returns a brace-initialised pair of (source_position_p3d, degree).
41 std::vector<std::vector<double>>&& channels_vec,
42 const std::string& respeaker_extrinsic_file,
const int sample_rate,
43 const double mic_distance) {
// Allocate and load the 4x4 extrinsic matrix only while the member pointer is
// still empty; later calls reuse it.
// NOTE(review): the value returned by LoadExtrinsics is not checked here, so
// a failed load still leaves a freshly allocated matrix behind -- confirm
// whether that is intended.
44 if (!respeaker2imu_ptr_.get()) {
45 respeaker2imu_ptr_.reset(
new Eigen::Matrix4d);
46 LoadExtrinsics(respeaker_extrinsic_file, respeaker2imu_ptr_.get());
// The start of this statement is not visible; presumably the result is stored
// in `degree`, which the following lines read -- TODO confirm.
49 EstimateDirection(move(channels_vec), sample_rate, mic_distance);
// Homogeneous point (w = 1) at distance kDistance in the z = 0 plane:
// x = kDistance * sin(degree), y = kDistance * cos(degree).
50 Eigen::Vector4d source_position(kDistance * sin(degree),
51 kDistance * cos(degree), 0, 1);
// Apply the loaded 4x4 extrinsic transform to that point.
52 source_position = (*respeaker2imu_ptr_) * source_position;
// Copy x/y/z into source_position_p3d (its declaration is not visible here).
55 source_position_p3d.set_x(source_position[0]);
56 source_position_p3d.set_y(source_position[1]);
57 source_position_p3d.set_z(source_position[2]);
// The reference text at the end of the file describes NormalizeAngle as
// normalising to [-PI, PI). Only the returned angle is normalised; the
// position above was computed from the un-normalised value.
58 degree = NormalizeAngle(degree);
59 return {source_position_p3d, degree};
// Estimates a direction angle from multi-channel audio and returns it in
// radians (the final line converts a degree value with / 180 * M_PI).
//
//   channels_vec - per-channel sample buffers; channels 0..3 are indexed below.
//   sample_rate  - passed to GccPhat as the sampling frequency.
//   mic_distance - divided by kSoundSpeed to get the largest delay considered.
//
// NOTE(review): channels_vec[0] and channels_ts[0..3] are accessed without any
// visible size check, so the caller presumably guarantees at least four
// non-empty channels -- confirm.
62double DirectionDetection::EstimateDirection(
63 std::vector<std::vector<double>>&& channels_vec,
const int sample_rate,
64 const double mic_distance) {
65 std::vector<torch::Tensor> channels_ts;
66 auto options = torch::TensorOptions().dtype(torch::kFloat64);
// Every channel is wrapped with the length of channel 0.
// NOTE(review): presumably all channels have the same length -- confirm.
67 int size =
static_cast<int>(channels_vec[0].size());
// Build one float64 tensor per channel directly over the vector's storage.
// NOTE(review): torch::from_blob is documented as not copying or owning the
// buffer, so channels_vec must outlive channels_ts -- confirm.
68 for (
auto& signal : channels_vec) {
69 channels_ts.push_back(torch::from_blob(signal.data(), {size}, options));
// tau0 / tau1 are assigned below; their declaration is not visible here.
73 double theta0, theta1;
74 const double max_tau = mic_distance / kSoundSpeed;
// Delay between channels 0 and 2 (interp = 1), converted to an angle in
// degrees via asin(tau / max_tau).
75 tau0 = GccPhat(channels_ts[0], channels_ts[2], sample_rate, max_tau, 1);
76 theta0 = asin(tau0 / max_tau) * 180 / M_PI;
// Same for channels 1 and 3.
77 tau1 = GccPhat(channels_ts[1], channels_ts[3], sample_rate, max_tau, 1);
78 theta1 = asin(tau1 / max_tau) * 180 / M_PI;
// Use the angle with the smaller magnitude; the sign of the other angle
// selects between the value wrapped into [0, 360) and its mirror (180 - x).
// NOTE(review): the declaration of best_guess and the `else` line between
// the two assignments are not visible; the use of `%` below implies an
// integer type, so the double results are presumably truncated -- confirm.
81 if (fabs(theta0) < fabs(theta1)) {
82 best_guess = theta1 > 0 ? std::fmod(theta0 + 360, 360) : (180 - theta0);
84 best_guess = theta0 < 0 ? std::fmod(theta1 + 360, 360) : (180 - theta1);
// Extra 270-degree offset applied after the theta1-based guess.
85 best_guess = (best_guess + 90 + 180) % 360;
// Negate the angle, add a fixed 480-degree offset, and wrap to [0, 360).
87 best_guess = (-best_guess + 480) % 360;
// Degrees -> radians.
89 return static_cast<double>(best_guess) / 180 * M_PI;
// Reads a rigid transform from a YAML file and writes it into
// *respeaker_extrinsic as a 4x4 homogeneous matrix.
//
//   yaml_file           - path of the YAML file to load.
//   respeaker_extrinsic - output matrix; dereferenced without a null check.
//
// The YAML keys read are transform.rotation.{w,x,y,z} and
// transform.translation.{x,y,z}, each as a double.
// NOTE(review): the function is declared bool, but no return statement is
// visible in this excerpt -- confirm which paths return false/true.
92bool DirectionDetection::LoadExtrinsics(
const std::string& yaml_file,
93 Eigen::Matrix4d* respeaker_extrinsic) {
// Logged when the file is missing. The guarding condition is not visible;
// NOTE(review): presumably a PathExists(yaml_file) check, as PathExists is
// listed in the reference text at the end of the file -- confirm.
95 AINFO << yaml_file <<
" does not exist!";
98 YAML::Node node = YAML::LoadFile(yaml_file);
// Logged when the load is considered failed (the condition is not visible).
108 AINFO <<
"Load " << yaml_file <<
" failed! please check!";
// Rotation quaternion components (declarations of qw..tz are not visible).
111 qw = node[
"transform"][
"rotation"][
"w"].as<
double>();
112 qx = node[
"transform"][
"rotation"][
"x"].as<
double>();
113 qy = node[
"transform"][
"rotation"][
"y"].as<
double>();
114 qz = node[
"transform"][
"rotation"][
"z"].as<
double>();
// Translation components.
115 tx = node[
"transform"][
"translation"][
"x"].as<
double>();
116 ty = node[
"transform"][
"translation"][
"y"].as<
double>();
117 tz = node[
"transform"][
"translation"][
"z"].as<
double>();
// YAML errors raised by the reads above are caught and logged here (the
// matching `try` line is not visible).
// NOTE(review): the message says "camera extrinsic file" although this
// function loads the respeaker extrinsics -- looks like a copy/paste leftover.
118 }
catch (YAML::Exception& e) {
119 AERROR <<
"load camera extrinsic file " << yaml_file
120 <<
" with error, YAML exception:" << e.what();
// Start from an all-zero matrix, then fill rotation, translation and the
// homogeneous 1.
123 respeaker_extrinsic->setConstant(0);
// NOTE(review): the lines assigning qw/qx/qy/qz into q are not visible.
124 Eigen::Quaterniond q;
// Top-left 3x3 block = rotation matrix of the normalised quaternion.
129 (*respeaker_extrinsic).block<3, 3>(0, 0) = q.normalized().toRotationMatrix();
// Last column = translation; bottom-right element = 1.
130 (*respeaker_extrinsic)(0, 3) = tx;
131 (*respeaker_extrinsic)(1, 3) = ty;
132 (*respeaker_extrinsic)(2, 3) = tz;
133 (*respeaker_extrinsic)(3, 3) = 1;
// Estimates the time delay between two signals: both are transformed to the
// frequency domain, the product of one spectrum with the conjugate of the
// other is divided by its magnitude, the result is transformed back, and the
// position of the largest absolute value inside a +/- max_shift window is
// converted to seconds.
//
//   sig, refsig - signals indexed along dimension 0.
//   fs          - sampling frequency used to convert samples <-> seconds.
//   max_tau     - largest delay considered; bounds the search window.
//   interp      - factor applied to n, fs and the shift-to-seconds conversion.
//
// NOTE(review): the return statement is not visible in this excerpt;
// presumably `tau` is returned -- confirm.
137double DirectionDetection::GccPhat(
const torch::Tensor& sig,
138 const torch::Tensor& refsig,
int fs,
139 double max_tau,
int interp) {
// n is the sum of the two lengths.
140 const int n_sig = sig.size(0), n_refsig = refsig.size(0),
141 n = n_sig + n_refsig;
// Zero-pad each signal at the end by the other's length, so both have length n.
142 torch::Tensor psig = at::constant_pad_nd(sig, {0, n_refsig}, 0);
143 torch::Tensor prefsig = at::constant_pad_nd(refsig, {0, n_sig}, 0);
// Torch minor version <= 7: at::rfft is used. The #else / #endif lines of
// this conditional are not visible in this excerpt.
144#if TORCH_VERSION_MINOR <= 7
145 psig = at::rfft(psig, 1,
false,
true);
146 prefsig = at::rfft(prefsig, 1,
false,
true);
// Presumably the newer-torch branch: at::fft_rfft yields a complex tensor,
// which is repacked as {real, imag} stacked along a new last dimension so the
// helper functions below can index "..., 0" and "..., 1".
148 auto psig_complex = at::fft_rfft(psig, c10::nullopt, -1, c10::nullopt);
149 psig = at::stack({torch::real(psig_complex), torch::imag(psig_complex)}, -1);
151 auto prefsig_complex = at::fft_rfft(prefsig, c10::nullopt, -1, c10::nullopt);
// Continuation of a statement whose first line is not visible (presumably
// `prefsig = at::stack(` -- TODO confirm).
153 {torch::real(prefsig_complex), torch::imag(prefsig_complex)}, -1);
// Negate the imaginary part of prefsig in place, then multiply the spectra.
156 ConjugateTensor(&prefsig);
157 torch::Tensor r = ComplexMultiply(psig, prefsig);
// Inverse transform of r divided by its magnitude. The assignment targets of
// the statements below are not visible; `cc` and `irfft_complex` are used later.
// NOTE(review): no guard against zero magnitude is visible, so an all-zero
// spectrum bin would divide by zero -- confirm inputs make this impossible.
// NOTE(review): this branch requests interp * n output samples while the
// other branch requests n; they agree only when interp == 1.
158#if TORCH_VERSION_MINOR <= 7
160 at::irfft(r / ComplexAbsolute(r), 1,
false,
true, {interp * n});
// Other branch: split the two columns back out and rebuild a complex tensor.
162 auto irfft_input_transpose = at::transpose(r / ComplexAbsolute(r), 0, 1);
164 torch::complex(irfft_input_transpose[0], irfft_input_transpose[1]);
166 torch::real(torch::fft::irfft(irfft_complex, n, -1, c10::nullopt));
// Search window: half the transform length, further limited by the number of
// samples corresponding to max_tau.
168 int max_shift =
static_cast<int>(interp * n / 2);
170 max_shift = std::min(
static_cast<int>(interp * fs * max_tau), max_shift);
// Reorder so the last max_shift samples come first, followed by the first
// max_shift + 1 samples; index max_shift then corresponds to zero shift.
172 auto begin = cc.index({Slice(cc.size(0) - max_shift, None)});
173 auto end = cc.index({Slice(None, max_shift + 1)});
174 cc = at::cat({begin, end});
// Signed shift of the largest absolute value, in samples.
176 const int shift = at::argmax(at::abs(cc), 0).item<
int>() - max_shift;
// Samples -> seconds.
177 const double tau = shift /
static_cast<double>(interp * fs);
// Negates, in place, every element at index 1 of the tensor's last dimension
// and leaves index 0 untouched.
// GccPhat stacks {real, imag} along the last dimension before calling this,
// so for that caller the effect is a complex conjugate.
//   tensor - modified in place; dereferenced without a null check.
182void DirectionDetection::ConjugateTensor(torch::Tensor* tensor) {
183 tensor->index_put_({
"...", 1}, -tensor->index({
"...", 1}));
// Element-wise product of two tensors whose last dimension holds a pair of
// values at index 0 and index 1, combined as (a0*b0 - a1*b1, a0*b1 + a1*b0),
// i.e. complex multiplication when index 0 is the real part and index 1 the
// imaginary part (the layout GccPhat builds before calling this).
// Returns a two-column tensor: column 0 is the first term, column 1 the second.
186torch::Tensor DirectionDetection::ComplexMultiply(
const torch::Tensor& a,
187 const torch::Tensor& b) {
// First component: a0*b0 - a1*b1.
188 torch::Tensor real = a.index({
"...", 0}) * b.index({
"...", 0}) -
189 a.index({
"...", 1}) * b.index({
"...", 1});
// Second component: a0*b1 + a1*b0.
190 torch::Tensor imag = a.index({
"...", 0}) * b.index({
"...", 1}) +
191 a.index({
"...", 1}) * b.index({
"...", 0});
// Flatten each component to a column and join them side by side (dim 1).
192 return at::cat({real.reshape({-1, 1}), imag.reshape({-1, 1})}, 1);
// Computes, for each row of `tensor`, the square root of the sum of squares
// along dimension 1, reshaped to a single column.
// With the two-column layout returned by ComplexMultiply this is the magnitude
// of each row's (index 0, index 1) pair.
// NOTE(review): the return statement is not visible in this excerpt;
// presumably `res` is returned -- confirm.
195torch::Tensor DirectionDetection::ComplexAbsolute(
const torch::Tensor& tensor) {
// Element-wise square.
196 torch::Tensor res = tensor * tensor;
// Sum over dimension 1, take the square root, keep the result as a column.
197 res = at::sqrt(res.sum(1)).reshape({-1, 1});
std::pair< Point3D, double > EstimateSoundSource(std::vector< std::vector< double > > &&channels_vec, const std::string &respeaker_extrinsic_file, const int sample_rate, const double mic_distance)
Math-related util functions.
double NormalizeAngle(const double angle)
Normalize angle to [-PI, PI).
bool PathExists(const std::string &path)
Check if the path exists.