83 {
84
85 CHECK_NOTNULL(obstacle_ptr);
86
88
89 int id = obstacle_ptr->id();
90 if (!obstacle_ptr->latest_feature().IsInitialized()) {
91 AERROR <<
"Obstacle [" <<
id <<
"] has no latest feature.";
92 return false;
93 }
94 Feature* latest_feature_ptr = obstacle_ptr->mutable_latest_feature();
95 CHECK_NOTNULL(latest_feature_ptr);
96
97
98
99
100 std::vector<double> feature_values;
102 if (FLAGS_prediction_offline_mode ==
105 "pedestrian", nullptr);
106 ADEBUG <<
"Saving extracted features for learning locally.";
107 return true;
108 }
109
110 static constexpr double kShortTermPredictionTimeResolution = 0.4;
111 static constexpr int kShortTermPredictionPointNum = 5;
112 static constexpr int kHiddenStateUpdateCycle = 4;
113
114
115 torch::Tensor social_pooling = GetSocialPooling();
116 std::vector<torch::jit::IValue> social_embedding_inputs;
117 social_embedding_inputs.push_back(std::move(social_pooling.to(device_)));
118 torch::Tensor social_embedding =
119 torch_social_embedding_.forward(social_embedding_inputs)
120 .toTensor()
121 .to(torch::kCPU);
122
123
124 double pos_x = feature_values[2];
125 double pos_y = feature_values[3];
126 double rel_x = 0.0;
127 double rel_y = 0.0;
128 if (obstacle_ptr->history_size() > kHiddenStateUpdateCycle - 1) {
129 rel_x = obstacle_ptr->latest_feature().position().x() -
130 obstacle_ptr->feature(3).position().x();
131 rel_y = obstacle_ptr->latest_feature().position().y() -
132 obstacle_ptr->feature(3).position().y();
133 }
134
135 torch::Tensor torch_position = torch::zeros({1, 2});
136 torch_position[0][0] = rel_x;
137 torch_position[0][1] = rel_y;
138 std::vector<torch::jit::IValue> position_embedding_inputs;
139 position_embedding_inputs.push_back(std::move(torch_position.to(device_)));
140 torch::Tensor position_embedding =
141 torch_position_embedding_.forward(position_embedding_inputs)
142 .toTensor()
143 .to(torch::kCPU);
144
145
146 torch::Tensor lstm_input =
147 torch::zeros({1, 2 * (kEmbeddingSize + kHiddenSize)});
148 for (int i = 0; i < kEmbeddingSize; ++i) {
149 lstm_input[0][i] = position_embedding[0][i];
150 }
151
152 if (obstacle_id_lstm_state_map_.find(id) ==
153 obstacle_id_lstm_state_map_.end()) {
154 obstacle_id_lstm_state_map_[id].ht = torch::zeros({1, 1, kHiddenSize});
155 obstacle_id_lstm_state_map_[id].ct = torch::zeros({1, 1, kHiddenSize});
156 obstacle_id_lstm_state_map_[id].timestamp = obstacle_ptr->timestamp();
157 obstacle_id_lstm_state_map_[id].frame_count = 0;
158 }
159 torch::Tensor curr_ht = obstacle_id_lstm_state_map_[id].ht;
160 torch::Tensor curr_ct = obstacle_id_lstm_state_map_[id].ct;
161 int curr_frame_count = obstacle_id_lstm_state_map_[id].frame_count;
162
163 if (curr_frame_count == kHiddenStateUpdateCycle - 1) {
164 for (int i = 0; i < kHiddenSize; ++i) {
165 lstm_input[0][kEmbeddingSize + i] = curr_ht[0][0][i];
166 lstm_input[0][kEmbeddingSize + kHiddenSize + i] = curr_ct[0][0][i];
167 }
168
169 std::vector<torch::jit::IValue> lstm_inputs;
170 lstm_inputs.push_back(std::move(lstm_input.to(device_)));
171 auto lstm_out_tuple = torch_single_lstm_.forward(lstm_inputs).toTuple();
172 auto ht = lstm_out_tuple->elements()[0].toTensor();
173 auto ct = lstm_out_tuple->elements()[1].toTensor();
174 obstacle_id_lstm_state_map_[id].ht = ht.clone();
175 obstacle_id_lstm_state_map_[id].ct = ct.clone();
176 }
177 obstacle_id_lstm_state_map_[id].frame_count =
178 (curr_frame_count + 1) % kHiddenStateUpdateCycle;
179
180
181
182 Trajectory* trajectory = latest_feature_ptr->add_predicted_trajectory();
183 trajectory->set_probability(1.0);
184 TrajectoryPoint* start_point = trajectory->add_trajectory_point();
185 start_point->mutable_path_point()->set_x(pos_x);
186 start_point->mutable_path_point()->set_y(pos_y);
187 start_point->mutable_path_point()->set_theta(latest_feature_ptr->theta());
188 start_point->set_v(latest_feature_ptr->speed());
189 start_point->set_relative_time(0.0);
190
191 for (int i = 1; i <= kShortTermPredictionPointNum; ++i) {
192 double prev_x = trajectory->trajectory_point(i - 1).path_point().x();
193 double prev_y = trajectory->trajectory_point(i - 1).path_point().y();
194 ACHECK(obstacle_id_lstm_state_map_.find(
id) !=
195 obstacle_id_lstm_state_map_.end());
196 torch::Tensor torch_position = torch::zeros({1, 2});
197 double curr_rel_x = rel_x;
198 double curr_rel_y = rel_y;
199 if (i > 1) {
200 curr_rel_x =
201 prev_x - trajectory->trajectory_point(i - 2).path_point().x();
202 curr_rel_y =
203 prev_y - trajectory->trajectory_point(i - 2).path_point().y();
204 }
205 torch_position[0][0] = curr_rel_x;
206 torch_position[0][1] = curr_rel_y;
207 std::vector<torch::jit::IValue> position_embedding_inputs;
208 position_embedding_inputs.push_back(std::move(torch_position.to(device_)));
209 torch::Tensor position_embedding =
210 torch_position_embedding_.forward(position_embedding_inputs)
211 .toTensor()
212 .to(torch::kCPU);
213 torch::Tensor lstm_input =
214 torch::zeros({1, kEmbeddingSize + 2 * kHiddenSize});
215 for (int i = 0; i < kEmbeddingSize; ++i) {
216 lstm_input[0][i] = position_embedding[0][i];
217 }
218
219 auto ht = obstacle_id_lstm_state_map_[id].ht.clone();
220 auto ct = obstacle_id_lstm_state_map_[id].ct.clone();
221
222 for (int i = 0; i < kHiddenSize; ++i) {
223 lstm_input[0][kEmbeddingSize + i] = ht[0][0][i];
224 lstm_input[0][kEmbeddingSize + kHiddenSize + i] = ct[0][0][i];
225 }
226 std::vector<torch::jit::IValue> lstm_inputs;
227 lstm_inputs.push_back(std::move(lstm_input.to(device_)));
228 auto lstm_out_tuple = torch_single_lstm_.forward(lstm_inputs).toTuple();
229 ht = lstm_out_tuple->elements()[0].toTensor();
230 ct = lstm_out_tuple->elements()[1].toTensor();
231 std::vector<torch::jit::IValue> prediction_inputs;
232 prediction_inputs.push_back(ht[0]);
233 auto pred_out_tensor = torch_prediction_layer_.forward(prediction_inputs)
234 .toTensor()
235 .to(torch::kCPU);
236 auto pred_out = pred_out_tensor.accessor<float, 2>();
237 TrajectoryPoint* point = trajectory->add_trajectory_point();
238 double curr_x = prev_x + static_cast<double>(pred_out[0][0]);
239 double curr_y = prev_y + static_cast<double>(pred_out[0][1]);
240 point->mutable_path_point()->set_x(curr_x);
241 point->mutable_path_point()->set_y(curr_y);
242 point->set_v(latest_feature_ptr->speed());
243 point->mutable_path_point()->set_theta(
244 latest_feature_ptr->velocity_heading());
245 point->set_relative_time(kShortTermPredictionTimeResolution *
246 static_cast<double>(i));
247 }
248
249 return true;
250}
static void InsertDataForLearning(const Feature &feature, const std::vector< double > &feature_values, const std::string &category, const LaneSequence *lane_sequence_ptr)
Insert a data_for_learning entry.
bool ExtractFeatures(const Obstacle *obstacle_ptr, std::vector< double > *feature_values)
Extract features for the learning model's input.
static const int kDumpDataForLearning