Apollo 10.0
自动驾驶开放平台
protobuf_arena_manager.cc
浏览该文件的文档.
1/******************************************************************************
2 * Copyright 2024 The Apollo Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *****************************************************************************/
17
18#include <sys/ipc.h>
19#include <sys/shm.h>
20
21#include <cstring>
22#include <string>
23
24#include <google/protobuf/arena.h>
25#include <google/protobuf/message.h>
26
28
29namespace apollo {
30namespace cyber {
31namespace transport {
32
33const int32_t ArenaSegmentBlock::kRWLockFree = 0;
34const int32_t ArenaSegmentBlock::kWriteExclusive = -1;
36
38 : channel_id_(0), key_id_(0), base_address_(nullptr) {}
39
40ArenaSegment::ArenaSegment(uint64_t channel_id)
41 : channel_id_(channel_id),
42 key_id_(std::hash<std::string>{}("/apollo/__arena__/" +
43 std::to_string(channel_id))) {}
44
45ArenaSegment::ArenaSegment(uint64_t channel_id, void* base_address)
46 : channel_id_(channel_id),
47 key_id_(std::hash<std::string>{}("/apollo/__arena__/" +
48 std::to_string(channel_id))),
49 base_address_(base_address) {}
50
51ArenaSegment::ArenaSegment(uint64_t channel_id, uint64_t message_size,
52 uint64_t block_num, void* base_address)
53 : channel_id_(channel_id),
54 key_id_(std::hash<std::string>{}("/apollo/__arena__/" +
55 std::to_string(channel_id))),
56 base_address_(base_address) {
57 Init(message_size, block_num);
58}
59
61
62bool ArenaSegment::Init(uint64_t message_size, uint64_t block_num) {
63 uint64_t key_id = std::hash<std::string>{}("/apollo/__arena__/" +
64 std::to_string(channel_id_));
65 // fprintf(stderr, "channel_id: %lx, key_id: %lx\n", channel_id_, key_id);
66
67 for (uint32_t retry = 0; retry < 2 && !OpenOrCreate(message_size, block_num);
68 ++retry) {
69 }
70
71 return true;
72}
73
74bool ArenaSegment::OpenOrCreate(uint64_t message_size, uint64_t block_num) {
75 auto arena_conf =
76 cyber::common::GlobalData::Instance()->GetChannelArenaConf(channel_id_);
77 auto shared_buffer_size = arena_conf.shared_buffer_size();
78 auto size = sizeof(ArenaSegmentState) +
79 sizeof(ArenaSegmentBlockDescriptor) * block_num +
80 message_size * block_num + shared_buffer_size;
81 auto shmid =
82 shmget(static_cast<key_t>(key_id_), size, 0644 | IPC_CREAT | IPC_EXCL);
83 if (shmid == -1) {
84 if (errno == EINVAL) {
85 // TODO(all): need larger space, recreate
86 } else if (errno == EEXIST) {
87 // TODO(all): shm already exist, open only
88 return Open(message_size, block_num);
89 } else {
90 // create or open shm failed
91 return false;
92 }
93 }
95 shm_address_ = shmat(shmid, base_address_, 0);
96 if (shm_address_ == reinterpret_cast<void*>(-1)) {
97 // shmat failed
98 return false;
99 }
100
101 arenas_.resize(block_num, nullptr);
102 if (shared_buffer_size == 0) {
103 shared_buffer_arena_ = nullptr;
104 } else {
105 google::protobuf::ArenaOptions options;
106 options.start_block_size = shared_buffer_size;
107 options.max_block_size = shared_buffer_size;
108 options.initial_block = reinterpret_cast<char*>(
109 reinterpret_cast<uint64_t>(shm_address_) + sizeof(ArenaSegmentState) +
110 block_num * sizeof(ArenaSegmentBlock) + block_num * message_size);
111 options.initial_block_size = shared_buffer_size;
112 shared_buffer_arena_ = std::make_shared<google::protobuf::Arena>(options);
113 }
114 for (size_t i = 0; i < block_num; i++) {
115 arena_block_address_.push_back(
116 reinterpret_cast<uint64_t>(shm_address_) + sizeof(ArenaSegmentState) +
117 block_num * sizeof(ArenaSegmentBlock) + i * message_size);
118 }
119
120 state_ = reinterpret_cast<ArenaSegmentState*>(shm_address_);
121 state_->struct_.ref_count_.store(1);
122 state_->struct_.auto_extended_.store(false);
124 state_->struct_.block_num_.store(block_num);
125 state_->struct_.message_seq_.store(0);
126 blocks_ = reinterpret_cast<ArenaSegmentBlock*>(
127 reinterpret_cast<uint64_t>(shm_address_) + sizeof(ArenaSegmentState));
128 for (uint64_t i = 0; i < block_num; ++i) {
130 // blocks_[i].writing_ref_count_.store(0);
131 // blocks_[i].reading_ref_count_.store(0);
132 blocks_[i].lock_num_.store(0);
133 }
134 return true;
135}
136
137bool ArenaSegment::Open(uint64_t message_size, uint64_t block_num) {
138 auto arena_conf =
139 cyber::common::GlobalData::Instance()->GetChannelArenaConf(channel_id_);
140 auto shared_buffer_size = arena_conf.shared_buffer_size();
141 auto shmid = shmget(static_cast<key_t>(key_id_), 0, 0644);
142 if (shmid == -1) {
143 // shm not exist
144 return false;
145 }
146 shm_address_ = shmat(shmid, base_address_, 0);
147 if (shm_address_ == reinterpret_cast<void*>(-1)) {
148 // shmat failed
149 return false;
150 }
152 state_ = reinterpret_cast<ArenaSegmentState*>(shm_address_);
153 state_->struct_.ref_count_.fetch_add(1);
154 blocks_ = reinterpret_cast<ArenaSegmentBlock*>(
155 reinterpret_cast<uint64_t>(shm_address_) + sizeof(ArenaSegmentState));
156
157 arenas_.resize(block_num, nullptr);
158 if (shared_buffer_size == 0) {
159 shared_buffer_arena_ = nullptr;
160 } else {
161 google::protobuf::ArenaOptions options;
162 options.start_block_size = shared_buffer_size;
163 options.max_block_size = shared_buffer_size;
164 options.initial_block = reinterpret_cast<char*>(
165 reinterpret_cast<uint64_t>(shm_address_) + sizeof(ArenaSegmentState) +
166 block_num * sizeof(ArenaSegmentBlock) + block_num * message_size);
167 options.initial_block_size = shared_buffer_size;
168 shared_buffer_arena_ = std::make_shared<google::protobuf::Arena>(options);
169 }
170 for (size_t i = 0; i < block_num; i++) {
171 arena_block_address_.push_back(
172 reinterpret_cast<uint64_t>(shm_address_) + sizeof(ArenaSegmentState) +
173 block_num * sizeof(ArenaSegmentBlock) + i * message_size);
174 }
175 return true;
176}
177
179
181 const auto block_num = state_->struct_.block_num_.load();
182 while (1) {
183 uint64_t next_idx = state_->struct_.message_seq_.fetch_add(1) % block_num;
184 if (AddBlockWriteLock(next_idx)) {
185 return next_idx;
186 }
187 }
188 return 0;
189}
190
191bool ArenaSegment::AddBlockWriteLock(uint64_t block_index) {
192 // base::WriteLockGuard<base::PthreadRWLock> lock(
193 // blocks_[block_index].read_write_mutex_);
194 // if (blocks_[block_index].writing_ref_count_.load() > 0) {
195 // return false;
196 // }
197 // if (blocks_[block_index].reading_ref_count_.load() > 0) {
198 // return false;
199 // }
200 // blocks_[block_index].writing_ref_count_.fetch_add(1);
201 auto& block = blocks_[block_index];
202 int32_t rw_lock_free = ArenaSegmentBlock::kRWLockFree;
203 if (!block.lock_num_.compare_exchange_weak(
205 std::memory_order_acq_rel, std::memory_order_relaxed)) {
206 ADEBUG << "lock num: " << block.lock_num_.load();
207 return false;
208 }
209 return true;
210}
211
212void ArenaSegment::RemoveBlockWriteLock(uint64_t block_index) {
213 blocks_[block_index].lock_num_.fetch_add(1);
214}
215
216bool ArenaSegment::AddBlockReadLock(uint64_t block_index) {
217 // when multiple readers are reading an arena channel
218 // at the same time, using write locks can cause one
219 // or more readers to hang at the lock, whereas
220 // read locks do not have this problem
221 // base::ReadLockGuard<base::PthreadRWLock> lock(
222 // blocks_[block_index].read_write_mutex_);
223 // if (blocks_[block_index].writing_ref_count_.load() > 0) {
224 // return false;
225 // }
226 // blocks_[block_index].reading_ref_count_.fetch_add(1);
227 auto& block = blocks_[block_index];
228 int32_t lock_num = block.lock_num_.load();
229 if (lock_num < ArenaSegmentBlock::kRWLockFree) {
230 AINFO << "block is being written.";
231 return false;
232 }
233
234 int32_t try_times = 0;
235 while (!block.lock_num_.compare_exchange_weak(lock_num, lock_num + 1,
236 std::memory_order_acq_rel,
237 std::memory_order_relaxed)) {
238 ++try_times;
239 if (try_times == ArenaSegmentBlock::kMaxTryLockTimes) {
240 AINFO << "fail to add read lock num, curr num: " << lock_num;
241 return false;
242 }
243
244 lock_num = block.lock_num_.load();
245 if (lock_num < ArenaSegmentBlock::kRWLockFree) {
246 AINFO << "block is being written.";
247 return false;
248 }
249 }
250 return true;
251}
252
253void ArenaSegment::RemoveBlockReadLock(uint64_t block_index) {
254 blocks_[block_index].lock_num_.fetch_sub(1);
255}
256
258 ArenaSegmentBlockInfo* block_info) {
259 if (!block_info) {
260 return false;
261 }
262 if (!state_ || !blocks_) {
263 return false;
264 }
265
266 // TODO(all): support dynamic block size
267
268 uint64_t block_num = state_->struct_.block_num_.load();
269 uint64_t block_size = state_->struct_.message_size_.load();
270 uint64_t block_index = GetNextWritableBlockIndex();
271 block_info->block_index_ = block_index;
272 block_info->block_ = &blocks_[block_index];
273 block_info->block_buffer_address_ = reinterpret_cast<void*>(
274 reinterpret_cast<uint64_t>(shm_address_) + sizeof(ArenaSegmentState) +
275 block_num * sizeof(ArenaSegmentBlock) + block_index * block_size);
276 return true;
277}
278
280 const ArenaSegmentBlockInfo& block_info) {
281 if (!state_ || !blocks_) {
282 return;
283 }
284 if (block_info.block_index_ >= state_->struct_.block_num_.load()) {
285 return;
286 }
288}
289
291 if (!block_info) {
292 return false;
293 }
294 if (!state_ || !blocks_) {
295 return false;
296 }
297
298 if (block_info->block_index_ >= state_->struct_.block_num_.load()) {
299 return false;
300 }
301
302 // TODO(all): support dynamic block size
303
304 if (!AddBlockReadLock(block_info->block_index_)) {
305 return false;
306 }
307 uint64_t block_num = state_->struct_.block_num_.load();
308 uint64_t block_size = state_->struct_.message_size_.load();
309
310 block_info->block_ = &blocks_[block_info->block_index_];
311 block_info->block_buffer_address_ = reinterpret_cast<void*>(
312 reinterpret_cast<uint64_t>(shm_address_) + sizeof(ArenaSegmentState) +
313 block_num * sizeof(ArenaSegmentBlock) +
314 block_info->block_index_ * block_size);
315 return true;
316}
317
319 if (!state_ || !blocks_) {
320 return;
321 }
322 if (block_info.block_index_ >= state_->struct_.block_num_.load()) {
323 return;
324 }
326}
327
328ProtobufArenaManager::ProtobufArenaManager() {
329 address_allocator_ = std::make_shared<ArenaAddressAllocator>();
330}
331
333
335 const message::ArenaMessageWrapper* wrapper) {
336 return 0;
337}
338
339std::shared_ptr<ArenaSegment> ProtobufArenaManager::GetSegment(
340 uint64_t channel_id) {
341 std::lock_guard<std::mutex> lock(segments_mutex_);
342 if (segments_.find(channel_id) == segments_.end()) {
343 return nullptr;
344 }
345 return segments_[channel_id];
346}
347
349 const void* message) {
350 auto input_msg = reinterpret_cast<const google::protobuf::Message*>(message);
351 auto channel_id = GetMessageChannelId(wrapper);
352 auto segment = GetSegment(channel_id);
353 auto arena_ptr = input_msg->GetArena();
354 google::protobuf::ArenaOptions options;
355
356 if (!segment) {
357 return nullptr;
358 }
359
360 void* msg_output = nullptr;
361 if (arena_ptr == nullptr) {
362 auto arena_conf =
363 cyber::common::GlobalData::Instance()->GetChannelArenaConf(channel_id);
364 google::protobuf::ArenaOptions options;
365 options.start_block_size = arena_conf.max_msg_size();
366 options.max_block_size = arena_conf.max_msg_size();
367
368 if (!segment) {
369 return nullptr;
370 }
371
373
375 // TODO(all): AcquireBlockToWrite for dynamic adjust block
376 // auto size = input_msg->ByteSizeLong();
377 uint64_t size = 0;
378 if (!segment->AcquireBlockToWrite(size, &wb)) {
379 return nullptr;
380 }
381 this->AddMessageRelatedBlock(wrapper, wb.block_index_);
382 options.initial_block =
383 reinterpret_cast<char*>(segment->arena_block_address_[wb.block_index_]);
384 options.initial_block_size = segment->message_capacity_;
385 if (segment->arenas_[wb.block_index_] != nullptr) {
386 segment->arenas_[wb.block_index_] = nullptr;
387 }
388 segment->arenas_[wb.block_index_] =
389 std::make_shared<google::protobuf::Arena>(options);
390 auto msg = input_msg->New(segment->arenas_[wb.block_index_].get());
391 msg->CopyFrom(*input_msg);
393 wrapper, reinterpret_cast<uint64_t>(msg) -
394 reinterpret_cast<uint64_t>(segment->GetShmAddress()));
395 msg_output = reinterpret_cast<void*>(msg);
396 segment->ReleaseWrittenBlock(wb);
397 } else {
399 int block_index = -1;
400 for (size_t i = 0; i < segment->message_capacity_; i++) {
401 if (segment->arenas_[i].get() == arena_ptr) {
402 block_index = i;
403 break;
404 }
405 }
406 if (block_index == -1) {
407 return nullptr;
408 }
409 wb.block_index_ = block_index;
411 this->AddMessageRelatedBlock(wrapper, block_index);
413 wrapper, reinterpret_cast<uint64_t>(input_msg) -
414 reinterpret_cast<uint64_t>(segment->GetShmAddress()));
415 msg_output = reinterpret_cast<void*>(
416 const_cast<google::protobuf::Message*>(input_msg));
417 segment->ReleaseWrittenBlock(wb);
418 }
419
420 return msg_output;
421}
422
424 auto segment = GetSegment(GetMessageChannelId(wrapper));
425 if (!segment) {
426 return nullptr;
427 }
428
429 auto address = reinterpret_cast<uint64_t>(segment->GetShmAddress()) +
431
432 return reinterpret_cast<void*>(address);
433}
434
436 if (init_) {
437 return true;
438 }
439
440 // do something
441
442 init_ = true;
443 return true;
444}
445
446bool ProtobufArenaManager::EnableSegment(uint64_t channel_id) {
447 if (segments_.find(channel_id) != segments_.end()) {
448 if (arena_buffer_callbacks_.find(channel_id) !=
449 arena_buffer_callbacks_.end()) {
450 arena_buffer_callbacks_[channel_id]();
451 }
452 return true;
453 }
454
455 // uint64_t asociated_channel_id = channel_id + 1;
456 // auto segment = SegmentFactory::CreateSegment(asociated_channel_id);
457 // segment->InitOnly(10 * 1024);
458 auto cyber_config = apollo::cyber::common::GlobalData::Instance()->Config();
459 if (!cyber_config.has_transport_conf()) {
460 return false;
461 }
462 if (!cyber_config.transport_conf().has_shm_conf()) {
463 return false;
464 }
465 if (!cyber_config.transport_conf().shm_conf().has_arena_shm_conf()) {
466 return false;
467 }
468 if (!cyber::common::GlobalData::Instance()->IsChannelEnableArenaShm(
469 channel_id)) {
470 return false;
471 }
472 auto arena_conf =
473 cyber::common::GlobalData::Instance()->GetChannelArenaConf(channel_id);
474 auto segment_shm_address = address_allocator_->Allocate(channel_id);
475 auto segment = std::make_shared<ArenaSegment>(
476 channel_id, arena_conf.max_msg_size(), arena_conf.max_pool_size(),
477 reinterpret_cast<void*>(segment_shm_address));
478 segments_[channel_id] = segment;
479 if (arena_buffer_callbacks_.find(channel_id) !=
480 arena_buffer_callbacks_.end()) {
481 arena_buffer_callbacks_[channel_id]();
482 }
483 return true;
484}
485
487 if (!init_) {
488 return true;
489 }
490
491 for (auto& segment : segments_) {
492 address_allocator_->Deallocate(segment.first);
493 }
494 for (auto& buffer : non_arena_buffers_) {
495 delete buffer.second;
496 }
497 segments_.clear();
498
499 init_ = false;
500 return true;
501}
502
504 message::ArenaMessageWrapper* wrapper, uint64_t channel_id) {
505 wrapper->GetExtended<ExtendedStruct>()->meta_.channel_id_ = channel_id;
506}
507
512
514 message::ArenaMessageWrapper* wrapper, uint64_t address_offset) {
515 wrapper->GetExtended<ExtendedStruct>()->meta_.address_offset_ =
516 address_offset;
517}
518
523
526 std::vector<uint64_t> related_blocks;
527 auto extended = wrapper->GetExtended<ExtendedStruct>();
528 for (uint64_t i = 0; i < extended->meta_.related_blocks_size_; ++i) {
529 related_blocks.push_back(extended->meta_.related_blocks_[i]);
530 }
531 return related_blocks;
532}
533
536 auto extended = wrapper->GetExtended<ExtendedStruct>();
537 extended->meta_.related_blocks_size_ = 0;
538 // memset(extended->meta_.related_blocks_, 0,
539 // sizeof(extended->meta_.related_blocks_));
540}
541
543 message::ArenaMessageWrapper* wrapper, uint64_t block_index) {
544 auto extended = wrapper->GetExtended<ExtendedStruct>();
545 if (extended->meta_.related_blocks_size_ >=
546 sizeof(extended->meta_.related_blocks_) / sizeof(uint64_t)) {
547 return;
548 }
549 extended->meta_.related_blocks_[extended->meta_.related_blocks_size_++] =
550 block_index;
551}
552
553ProtobufArenaManager::ArenaAllocCallback ProtobufArenaManager::arena_alloc_cb_ =
554 nullptr;
555
556void* ProtobufArenaManager::ArenaAlloc(uint64_t size) {
557 return arena_alloc_cb_ ? arena_alloc_cb_(size) : nullptr;
558}
559
560void ProtobufArenaManager::ArenaDealloc(void* addr, uint64_t size) {}
561
562} // namespace transport
563} // namespace cyber
564} // namespace apollo
bool Init(uint64_t message_size, uint64_t block_num)
bool Open(uint64_t message_size, uint64_t block_num)
std::vector< std::shared_ptr< google::protobuf::Arena > > arenas_
bool OpenOrCreate(uint64_t message_size, uint64_t block_num)
std::shared_ptr< google::protobuf::Arena > shared_buffer_arena_
bool AcquireBlockToRead(ArenaSegmentBlockInfo *block_info)
void ReleaseReadBlock(const ArenaSegmentBlockInfo &block_info)
void ReleaseWrittenBlock(const ArenaSegmentBlockInfo &block_info)
bool AcquireBlockToWrite(uint64_t size, ArenaSegmentBlockInfo *block_info)
uint64_t GetBaseAddress(const message::ArenaMessageWrapper *wrapper) override
std::shared_ptr< ArenaSegment > GetSegment(uint64_t channel_id)
void ResetMessageRelatedBlocks(message::ArenaMessageWrapper *wrapper)
void * GetMessage(message::ArenaMessageWrapper *wrapper) override
void SetMessageAddressOffset(message::ArenaMessageWrapper *wrapper, uint64_t offset)
uint64_t GetMessageAddressOffset(message::ArenaMessageWrapper *wrapper)
void SetMessageChannelId(message::ArenaMessageWrapper *wrapper, uint64_t channel_id)
uint64_t GetMessageChannelId(message::ArenaMessageWrapper *wrapper)
void * SetMessage(message::ArenaMessageWrapper *wrapper, const void *message) override
void AddMessageRelatedBlock(message::ArenaMessageWrapper *wrapper, uint64_t block_index)
std::vector< uint64_t > GetMessageRelatedBlocks(message::ArenaMessageWrapper *wrapper)
int message_size
#define ADEBUG
Definition log.h:41
#define AINFO
Definition log.h:42
bool Init(const char *binary_name, const std::string &dag_info)
Definition init.cc:98
class register implement
Definition arena_queue.h:37
Definition future.h:29
struct apollo::cyber::transport::ArenaSegmentState::@3 struct_
struct apollo::cyber::transport::ExtendedStruct::@5 meta_