Skip to content

Commit

Permalink
fix thread local4
Browse files Browse the repository at this point in the history
  • Loading branch information
xinyiZzz committed Oct 25, 2023
1 parent 5c82eb6 commit 38dac74
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 66 deletions.
2 changes: 1 addition & 1 deletion be/src/runtime/memory/mem_tracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void MemTracker::bind_parent(MemTrackerLimiter* parent) {
if (parent) {
_parent_label = parent->label();
_parent_group_num = parent->group_num();
} else if (doris::thread_context_ptr_init) {
} else if (doris::is_thread_context_init()) {
_parent_label = thread_context()->thread_mem_tracker()->label();
_parent_group_num = thread_context()->thread_mem_tracker()->group_num();
}
Expand Down
2 changes: 1 addition & 1 deletion be/src/runtime/runtime_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ Status RuntimeState::check_query_state(const std::string& msg) {
//
// If the thread MemTrackerLimiter exceeds the limit, an error status is returned.
// Usually used after SCOPED_ATTACH_TASK, during query execution.
if (thread_context_ptr_init && thread_context()->thread_mem_tracker()->limit_exceeded() &&
if (is_thread_context_init() && thread_context()->thread_mem_tracker()->limit_exceeded() &&
!config::enable_query_memory_overcommit) {
auto failed_msg = thread_context()->thread_mem_tracker()->query_tracker_limit_exceeded_str(
thread_context()->thread_mem_tracker()->tracker_limit_exceeded_str(),
Expand Down
161 changes: 103 additions & 58 deletions be/src/runtime/thread_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,11 @@ extern bool k_doris_exit;
extern bthread_key_t btls_key;

// Is true after ThreadContext construction.
inline thread_local bool thread_context_ptr_init = false;
inline thread_local bool pthread_context_ptr_init = false;
inline thread_local constinit ThreadContext* thread_context_ptr;

// To avoid performance problems caused by frequently calling `bthread_getspecific` to obtain bthread TLS
// cache the key and value of bthread TLS in pthread TLS.
inline thread_local constinit ThreadContext* bthread_context;
inline thread_local constinit ThreadContext* bthread_context = nullptr;
inline thread_local bthread_t bthread_id;

// The thread context saves some info about a working thread.
Expand All @@ -143,6 +142,12 @@ class ThreadContext {
}
}

~ThreadContext() {
pthread_context_ptr_init = false;
bthread_context = nullptr;
bthread_id = bthread_self(); // Avoid CONSUME_MEM_TRACKER call pthread_getspecific.
}

void attach_task(const TUniqueId& task_id, const TUniqueId& fragment_instance_id,
const std::shared_ptr<MemTrackerLimiter>& mem_tracker) {
#ifndef BE_TEST
Expand Down Expand Up @@ -200,38 +205,33 @@ class ThreadLocalHandle {
public:
static void handle_thread_local() {
if (bthread_self() == 0) {
if (!thread_context_ptr_init) {
if (!pthread_context_ptr_init) {
DCHECK(bthread_equal(0, bthread_id)); // Not used in bthread before.
thread_context_ptr = new ThreadContext();
thread_context_ptr_init = true;
pthread_context_ptr_init = true;
}
DCHECK(thread_context_ptr != nullptr);
thread_context_ptr->handle_thread_local_count++;
} else {
if (!thread_context_ptr_init) {
// Avoid calling bthread_getspecific frequently to get bthread local.
// Very frequent bthread_getspecific will slow, but handle_thread_local is not expected to be much.
// Cache the pointer of bthread local in pthead local.
bthread_context = static_cast<ThreadContext*>(bthread_getspecific(btls_key));
if (bthread_context == nullptr) {
// A new bthread starts, two scenarios:
// 1. First call to bthread_getspecific (and before any bthread_setspecific) returns NULL
// 2. There are not enough reusable btls in btls pool.
// else, two scenarios:
// 1. A new bthread starts, but get a reuses btls.
// 2. A pthread switch occurs. Because the pthread switch cannot be accurately identified at the moment.
// So tracker call reset 0 like reuses btls.
// during this period, stop the use of thread_context.
bthread_context = new ThreadContext;
// The brpc server should respond as quickly as possible.
bthread_context->thread_mem_tracker_mgr->disable_wait_gc();
// set the data so that next time bthread_getspecific in the thread returns the data.
CHECK((0 == bthread_setspecific(btls_key, bthread_context)) ||
doris::k_doris_exit);
}
DCHECK(bthread_context->handle_thread_local_count == 0);
bthread_id = bthread_self();
thread_context_ptr_init = true;
// Avoid calling bthread_getspecific frequently to get bthread local.
// Very frequent bthread_getspecific will slow, but handle_thread_local is not expected to be much.
// Cache the pointer of bthread local in pthead local.
bthread_id = bthread_self();
bthread_context = static_cast<ThreadContext*>(bthread_getspecific(btls_key));
if (bthread_context == nullptr) {
// A new bthread starts, two scenarios:
// 1. First call to bthread_getspecific (and before any bthread_setspecific) returns NULL
// 2. There are not enough reusable btls in btls pool.
// else, two scenarios:
// 1. A new bthread starts, but get a reuses btls.
// 2. A pthread switch occurs. Because the pthread switch cannot be accurately identified at the moment.
// So tracker call reset 0 like reuses btls.
// during this period, stop the use of thread_context.
bthread_context = new ThreadContext;
// The brpc server should respond as quickly as possible.
bthread_context->thread_mem_tracker_mgr->disable_wait_gc();
// set the data so that next time bthread_getspecific in the thread returns the data.
CHECK((0 == bthread_setspecific(btls_key, bthread_context)) || doris::k_doris_exit);
}
DCHECK(bthread_context != nullptr);
bthread_context->handle_thread_local_count++;
Expand All @@ -241,40 +241,62 @@ class ThreadLocalHandle {
// `handle_thread_local` and `handle_thread_local` should be used in pairs,
// `release_thread_local` should only be called if `handle_thread_local` returns true
static void release_thread_local() {
DCHECK(thread_context_ptr_init);
if (bthread_self() == 0) {
DCHECK(pthread_context_ptr_init);
thread_context_ptr->handle_thread_local_count--;
if (thread_context_ptr->handle_thread_local_count == 0) {
pthread_context_ptr_init = false;
delete doris::thread_context_ptr;
thread_context_ptr = nullptr;
thread_context_ptr_init = false;
}
} else {
if (!bthread_equal(bthread_self(), bthread_id)) {
bthread_id = bthread_self();
bthread_context = static_cast<ThreadContext*>(bthread_getspecific(btls_key));
DCHECK(bthread_context != nullptr);
}
bthread_context->handle_thread_local_count--;
if (bthread_context->handle_thread_local_count == 0) {
bthread_id = 0;
bthread_context = nullptr;
thread_context_ptr_init = false;
}
}
}
};

// must call handle_thread_local() before use thread_context().
static ThreadContext* thread_context() {
DCHECK(thread_context_ptr_init);
if (bthread_self() == 0) {
static bool is_thread_context_init() {
if (pthread_context_ptr_init) {
// in pthread
DCHECK(thread_context_ptr != nullptr);
DCHECK(bthread_equal(0, bthread_id)); // Not used in bthread before.
return thread_context_ptr;
} else {
return true;
} else if (bthread_self() != 0) {
// in bthread
DCHECK(!pthread_context_ptr_init);
if (!bthread_equal(bthread_self(), bthread_id)) {
// bthread switching pthread may be very frequent, remember not to use lock or other time-consuming operations.
bthread_id = bthread_self();
bthread_context = static_cast<ThreadContext*>(bthread_getspecific(btls_key));
}
if (doris::bthread_context == nullptr) {
return false;
} else {
return true;
}
} else {
return false;
}
}

// must call handle_thread_local() and is_thread_context_init() before use thread_context().
static ThreadContext* thread_context() {
if (bthread_self() == 0) {
// in pthread
DCHECK(pthread_context_ptr_init);
DCHECK(thread_context_ptr != nullptr);
return thread_context_ptr;
} else {
// in bthread
DCHECK(bthread_context != nullptr);
return bthread_context;
}
Expand Down Expand Up @@ -328,7 +350,7 @@ class SwitchThreadMemTrackerLimiter {
class TrackMemoryToUnknown {
public:
explicit TrackMemoryToUnknown() {
if (thread_context_ptr_init) {
if (is_thread_context_init()) {
if (bthread_self() != 0) {
_tid = std::this_thread::get_id(); // save pthread id
}
Expand All @@ -339,8 +361,8 @@ class TrackMemoryToUnknown {
}

~TrackMemoryToUnknown() {
if (_old_mem_tracker != nullptr) {
DCHECK(thread_context_ptr_init);
if (is_thread_context_init()) {
DCHECK(_old_mem_tracker != nullptr);
if (bthread_self() != 0) {
// make sure pthread is not switch, if switch, mem tracker will be wrong, but not crash in release
DCHECK(_tid == std::this_thread::get_id());
Expand Down Expand Up @@ -373,12 +395,27 @@ class AddThreadMemTrackerConsumer {
// Basic macros for mem tracker, usually do not need to be modified and used.
#ifdef USE_MEM_TRACKER
// used to fix the tracking accuracy of caches.
#define THREAD_MEM_TRACKER_TRANSFER_TO(size, tracker) \
doris::thread_context()->thread_mem_tracker_mgr->limiter_mem_tracker_raw()->transfer_to( \
size, tracker)
#define THREAD_MEM_TRACKER_TRANSFER_FROM(size, tracker) \
tracker->transfer_to( \
size, doris::thread_context()->thread_mem_tracker_mgr->limiter_mem_tracker_raw())
#define THREAD_MEM_TRACKER_TRANSFER_TO(size, tracker) \
do { \
if (is_thread_context_init()) { \
doris::thread_context() \
->thread_mem_tracker_mgr->limiter_mem_tracker_raw() \
->transfer_to(size, tracker); \
} else { \
doris::ExecEnv::GetInstance()->orphan_mem_tracker_raw()->transfer_to(size, tracker); \
} \
} while (0)

#define THREAD_MEM_TRACKER_TRANSFER_FROM(size, tracker) \
do { \
if (is_thread_context_init()) { \
tracker->transfer_to( \
size, \
doris::thread_context()->thread_mem_tracker_mgr->limiter_mem_tracker_raw()); \
} else { \
tracker->transfer_to(size, doris::ExecEnv::GetInstance()->orphan_mem_tracker_raw()); \
} \
} while (0)

// Mem Hook to consume thread mem tracker
// TODO: In the original design, the MemTracker consume method is called before the memory is allocated.
Expand All @@ -387,21 +424,29 @@ class AddThreadMemTrackerConsumer {
// which is different from the previous behavior.
#define CONSUME_MEM_TRACKER(size) \
do { \
if (doris::thread_context_ptr_init) { \
doris::thread_context()->consume_memory(size); \
if (doris::pthread_context_ptr_init) { \
DCHECK(bthread_self() == 0); \
DCHECK(doris::thread_context_ptr != nullptr); \
DCHECK(bthread_equal(0, doris::bthread_id)); \
doris::thread_context_ptr->consume_memory(size); \
} else if (bthread_self() != 0) { \
DCHECK(!doris::pthread_context_ptr_init); \
if (!bthread_equal(bthread_self(), doris::bthread_id)) { \
doris::bthread_id = bthread_self(); \
doris::bthread_context = \
static_cast<doris::ThreadContext*>(bthread_getspecific(doris::btls_key)); \
} \
if (doris::bthread_context == nullptr) { \
doris::ExecEnv::GetInstance()->orphan_mem_tracker_raw()->consume_no_update_peak( \
size); \
} else { \
doris::bthread_context->consume_memory(size); \
} \
} else if (doris::ExecEnv::GetInstance()->initialized()) { \
doris::ExecEnv::GetInstance()->orphan_mem_tracker_raw()->consume_no_update_peak(size); \
} \
} while (0)
#define RELEASE_MEM_TRACKER(size) \
do { \
if (doris::thread_context_ptr_init) { \
doris::thread_context()->consume_memory(-size); \
} else if (doris::ExecEnv::GetInstance()->initialized()) { \
doris::ExecEnv::GetInstance()->orphan_mem_tracker_raw()->consume_no_update_peak( \
-size); \
} \
} while (0)
#define RELEASE_MEM_TRACKER(size) CONSUME_MEM_TRACKER(-size)
#else
#define THREAD_MEM_TRACKER_TRANSFER_TO(size, tracker) (void)0
#define THREAD_MEM_TRACKER_TRANSFER_FROM(size, tracker) (void)0
Expand Down
12 changes: 6 additions & 6 deletions be/src/vec/common/allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

template <bool clear_memory_, bool mmap_populate, bool use_mmap>
void Allocator<clear_memory_, mmap_populate, use_mmap>::sys_memory_check(size_t size) const {
if (doris::thread_context_ptr_init && doris::thread_context()->skip_memory_check) {
if (doris::is_thread_context_init() && doris::thread_context()->skip_memory_check) {
return;
}
if (doris::MemTrackerLimiter::sys_mem_exceed_limit_check(size)) {
Expand All @@ -49,10 +49,10 @@ void Allocator<clear_memory_, mmap_populate, use_mmap>::sys_memory_check(size_t
"Allocator sys memory check failed: Cannot alloc:{}, consuming "
"tracker:<{}>, exec node:<{}>, {}.",
size,
doris::thread_context_ptr_init
doris::is_thread_context_init()
? doris::thread_context()->thread_mem_tracker()->label()
: "Orphan",
doris::thread_context_ptr_init
doris::is_thread_context_init()
? doris::thread_context()->thread_mem_tracker_mgr->last_consumer_tracker()
: "",
doris::MemTrackerLimiter::process_limit_exceeded_errmsg_str());
Expand All @@ -62,15 +62,15 @@ void Allocator<clear_memory_, mmap_populate, use_mmap>::sys_memory_check(size_t
}

// TODO, Save the query context in the thread context, instead of finding whether the query id is canceled in fragment_mgr.
if (doris::thread_context_ptr_init &&
if (doris::is_thread_context_init() &&
doris::ExecEnv::GetInstance()->fragment_mgr()->query_is_canceled(
doris::thread_context()->task_id())) {
if (doris::enable_thread_catch_bad_alloc) {
throw doris::Exception(doris::ErrorCode::MEM_ALLOC_FAILED, err_msg);
}
return;
}
if (doris::thread_context_ptr_init && !doris::config::disable_memory_gc &&
if (doris::is_thread_context_init() && !doris::config::disable_memory_gc &&
doris::thread_context()->thread_mem_tracker_mgr->is_attach_query() &&
doris::thread_context()->thread_mem_tracker_mgr->wait_gc()) {
int64_t wait_milliseconds = 0;
Expand Down Expand Up @@ -126,7 +126,7 @@ void Allocator<clear_memory_, mmap_populate, use_mmap>::sys_memory_check(size_t

template <bool clear_memory_, bool mmap_populate, bool use_mmap>
void Allocator<clear_memory_, mmap_populate, use_mmap>::memory_tracker_check(size_t size) const {
if (doris::thread_context_ptr_init && doris::thread_context()->skip_memory_check) {
if (doris::is_thread_context_init() && doris::thread_context()->skip_memory_check) {
return;
}
auto st = doris::thread_context()->thread_mem_tracker()->check_limit(size);
Expand Down

0 comments on commit 38dac74

Please sign in to comment.