Apollo 10.0
自动驾驶开放平台
apollo::prediction::MultiAgentVehicleCpuTorch类 参考

#include <multi_agent_vehicle_torch_model.h>

类 apollo::prediction::MultiAgentVehicleCpuTorch 继承关系图:
apollo::prediction::MultiAgentVehicleCpuTorch 的协作图:

Public 成员函数

 MultiAgentVehicleCpuTorch ()
 Construct a new Multi Agent Vehicle Cpu Torch object
 
 ~MultiAgentVehicleCpuTorch ()
 Destroy the Multi Agent Vehicle Cpu Torch object
 
virtual bool Init ()
 parse model description class and load the model
 
virtual bool Inference (const std::vector< void * > &input_buffer, unsigned int input_size, std::vector< void * > *output_buffer, unsigned int output_size)
 performing network inference
 
virtual bool LoadModel ()
 load the model from file
 
virtual void Destory ()
 free all memory requested, gpu or cpu
 
- Public 成员函数 继承自 apollo::prediction::ModelBase
 ModelBase ()
 
 ~ModelBase ()
 

额外继承的成员函数

- Public 属性 继承自 apollo::prediction::ModelBase
std::string model_path_
 
uint8_t init_ = 0
 

详细描述

在文件 multi_agent_vehicle_torch_model.h 第 30 行定义.

构造及析构函数说明

◆ MultiAgentVehicleCpuTorch()

apollo::prediction::MultiAgentVehicleCpuTorch::MultiAgentVehicleCpuTorch ( )
inline

Construct a new Multi Agent Vehicle Cpu Torch object

在文件 multi_agent_vehicle_torch_model.h 第 36 行定义.

36{}

◆ ~MultiAgentVehicleCpuTorch()

apollo::prediction::MultiAgentVehicleCpuTorch::~MultiAgentVehicleCpuTorch ( )
inline

Destroy the Multi Agent Vehicle Cpu Torch object

在文件 multi_agent_vehicle_torch_model.h 第 41 行定义.

41{Destory();}
virtual void Destory()
free all memory requested, gpu or cpu

成员函数说明

◆ Destory()

void apollo::prediction::MultiAgentVehicleCpuTorch::Destory ( )
virtual

free all memory requested, gpu or cpu

返回
memory release result, true for success

实现了 apollo::prediction::ModelBase.

在文件 multi_agent_vehicle_torch_model.cc 第 147 行定义.

147{}

◆ Inference()

bool apollo::prediction::MultiAgentVehicleCpuTorch::Inference ( const std::vector< void * > &  input_buffer,
unsigned int  input_size,
std::vector< void * > *  output_buffer,
unsigned int  output_size 
)
virtual

performing network inference

参数
input_buffer  vector of input tensor
input_size  size of input_buffer
output_buffer  vector of output tensor
output_size  size of output_buffer
返回
inference result, true for success

实现了 apollo::prediction::ModelBase.

在文件 multi_agent_vehicle_torch_model.cc 第 96 行定义.

98 {
99 ACHECK(input_size == input_buffer.size() && input_size == 8);
100 ACHECK(output_size == output_buffer->size() && output_size == 1);
101
102 if (init_ == 0) {
103 Init();
104 }
105
106 auto device = torch::Device(torch::kCPU);
107 if (torch::cuda::is_available()) {
108 device = torch::Device(torch::kCUDA);
109 }
110 torch::Tensor target_obstacle_pos =
111 torch::from_blob(input_buffer[0], {1, max_agent_num, 20, 2});
112 torch::Tensor target_obstacle_pos_step =
113 torch::from_blob(input_buffer[1], {1, max_agent_num, 20, 2});
114 torch::Tensor vector_data =
115 torch::from_blob(input_buffer[2], {1, 450, 50, 9});
116 auto options = torch::TensorOptions().dtype(torch::kBool);
117 torch::Tensor vector_mask =
118 torch::from_blob(input_buffer[3], {1, 450, 50}, options);
119 torch::Tensor polyline_mask =
120 torch::from_blob(input_buffer[4], {1, 450}, options);
121 torch::Tensor rand_mask =
122 torch::from_blob(input_buffer[5], {1, 450}, options);
123 torch::Tensor polyline_id =
124 torch::from_blob(input_buffer[6], {1, 450, 2});
125 torch::Tensor obs_position =
126 torch::from_blob(input_buffer[7], {1, max_agent_num, 3});
127
128 std::vector<torch::jit::IValue> torch_inputs;
129
130 torch_inputs.push_back(c10::ivalue::Tuple::create(
131 {std::move(target_obstacle_pos.to(device)),
132 std::move(target_obstacle_pos_step.to(device)),
133 std::move(vector_data.to(device)),
134 std::move(vector_mask.to(device)),
135 std::move(polyline_mask.to(device)),
136 std::move(rand_mask.to(device)),
137 std::move(polyline_id.to(device)),
138 std::move(obs_position.to(device))}));
139
140 torch::Tensor torch_output_tensor =
141 model_instance_.forward(torch_inputs).toTensor().to(torch::kCPU);
142 memcpy((*output_buffer)[0], torch_output_tensor.data_ptr<float>(),
143 1 * max_agent_num * 30 * 2 * sizeof(float));
144 return true;
145}
virtual bool Init()
parse model description class and load the model
first check imu-to-ant offset saved in device
Definition readme.txt:2
#define ACHECK(cond)
Definition log.h:80

◆ Init()

bool apollo::prediction::MultiAgentVehicleCpuTorch::Init ( )
virtual

parse model description class and load the model

参数
config_path  model config path
返回
init result, true for success

实现了 apollo::prediction::ModelBase.

在文件 multi_agent_vehicle_torch_model.cc 第 31 行定义.

31 {
32 ModelConf model_config;
33 int status;
34
35 if (init_ != 0) {
36 return true;
37 }
38
39 std::string class_name =
40 abi::__cxa_demangle(typeid(*this).name(), 0, 0, &status);
41
42 std::string default_config_path =
43     cyber::plugin_manager::PluginManager::Instance()
44         ->GetPluginConfPath<ModelBase>(class_name,
45 "conf/default_conf.pb.txt");
46
47 if (!cyber::common::GetProtoFromFile(default_config_path, &model_config)) {
48 AERROR << "Unable to load model conf file: " << default_config_path;
49 return false;
50 }
51 model_path_ = model_config.model_path();
52 init_ = 1;
53
54 return LoadModel();
55}
std::string GetPluginConfPath(const std::string &class_name, const std::string &conf_name)
get plugin configuration file location
static PluginManager * Instance()
get singleton instance of PluginManager
#define AERROR
Definition log.h:44
bool GetProtoFromFile(const std::string &file_name, google::protobuf::Message *message)
Parses the content of the file specified by the file_name as a representation of protobufs,...
Definition file.cc:132

◆ LoadModel()

bool apollo::prediction::MultiAgentVehicleCpuTorch::LoadModel ( )
virtual

load the model from file

返回
loading result, true for success

实现了 apollo::prediction::ModelBase.

在文件 multi_agent_vehicle_torch_model.cc 第 57 行定义.

57 {
58 auto device = torch::Device(torch::kCPU);
59 if (torch::cuda::is_available()) {
60 device = torch::Device(torch::kCUDA);
61 }
62
63 model_instance_ = torch::jit::load(model_path_, device);
64
65 torch::set_num_threads(1);
66
67 // Fake intput for the first frame
68 torch::Tensor target_obstacle_pos =
69 torch::randn({1, max_agent_num, 20, 2});
70 torch::Tensor target_obstacle_pos_step =
71 torch::randn({1, max_agent_num, 20, 2});
72 torch::Tensor vector_data = torch::randn({1, 450, 50, 9});
73 torch::Tensor vector_mask = torch::randn({1, 450, 50}) > 0.9;
74 torch::Tensor polyline_mask = torch::randn({1, 450}) > 0.9;
75 torch::Tensor rand_mask = torch::zeros({450}).toType(at::kBool);
76 torch::Tensor polyline_id = torch::randn({1, 450, 2});
77 torch::Tensor obs_position = torch::randn({1, max_agent_num, 3});
78 std::vector<torch::jit::IValue> torch_inputs;
79 torch::Tensor torch_default_output_tensor;
80
81 torch_inputs.push_back(c10::ivalue::Tuple::create(
82 {std::move(target_obstacle_pos.to(device)),
83 std::move(target_obstacle_pos_step.to(device)),
84 std::move(vector_data.to(device)),
85 std::move(vector_mask.to(device)),
86 std::move(polyline_mask.to(device)),
87 std::move(rand_mask.to(device)),
88 std::move(polyline_id.to(device)),
89 std::move(obs_position.to(device))}));
90
91 // warm up to avoid very slow first inference later
92 WarmUp(torch_inputs, &model_instance_, &torch_default_output_tensor);
93 return true;
94}
void WarmUp(const std::vector< torch::jit::IValue > &torch_inputs, torch::jit::script::Module *model, at::Tensor *default_output_ptr)
warm up function to avoid slowly inference of torch model
Definition warm_up.cc:28

该类的文档由以下文件生成: