Skip to content

Commit

Permalink
Add more safety checks to BlockingCounter
Browse files Browse the repository at this point in the history
This makes it harder to have use-after-free situations.

PiperOrigin-RevId: 705643798
  • Loading branch information
majnemer authored and copybara-github committed Dec 13, 2024
1 parent 83c149d commit f182dba
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 14 deletions.
5 changes: 4 additions & 1 deletion tsl/platform/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ cc_library(
compatible_with = get_compatible_with_portable(),
deps = [
":logging",
":mutex",
":stacktrace",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
],
)

Expand Down
77 changes: 64 additions & 13 deletions tsl/platform/blocking_counter.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@ limitations under the License.
#define TENSORFLOW_TSL_PLATFORM_BLOCKING_COUNTER_H_

#include <atomic>
#include <chrono> // NOLINT
#include <cstdint>

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/mutex.h"
#include "tsl/platform/stacktrace.h"

namespace tsl {

Expand All @@ -28,51 +33,97 @@ class BlockingCounter {
// Constructs a counter holding `initial_count`. The count is stored shifted
// left by one bit in state_ (the low bit is reserved as the waiter flag), and
// notified_ starts out false.
BlockingCounter(int initial_count)
: state_(initial_count << 1), notified_(false) {
// The initial count must be non-negative.
CHECK_GE(initial_count, 0);
// Debug-only check that shifting left by one and back again recovers
// initial_count, i.e. no high bit was lost when packing it into state_.
// NOTE(review): this listing is a rendered diff, so the signed form below and
// the unsigned form after it look like the before/after versions of the same
// check -- confirm which one is live in the header.
DCHECK_EQ((initial_count << 1) >> 1, initial_count);
DCHECK_EQ((static_cast<unsigned int>(initial_count) << 1) >> 1,
initial_count);
}

// Destructor; performs no custom work.
// NOTE(review): both spellings appear because this listing is a rendered diff;
// presumably `= default` replaces the empty body -- confirm.
~BlockingCounter() {}
~BlockingCounter() = default;

// Thread-local marker object. Only its address is used: Wait() and WaitFor()
// store &kNonce in last_waiter_addr_ and compare against it, so the address
// presumably serves as an identity for the calling thread.
static thread_local constexpr char kNonce = 0;

// Decrements the count by one. If that brings the count to zero while the
// waiter flag (low bit of state_) is set, sets notified_ under mu_ and wakes
// everything blocked on cond_var_.
inline void DecrementCount() {
// The count lives above bit 0, so subtracting 2 lowers it by one. fetch_sub
// returns the previous value; the trailing `- 2` yields the new state.
unsigned int v = state_.fetch_sub(2, std::memory_order_acq_rel) - 2;
if (v != 1) {
// v + 2 is the state before this decrement; with the waiter bit masked off
// it must be non-zero, i.e. the count was not already zero.
DCHECK_NE(((v + 2) & ~1), 0);
return; // either count has not dropped to 0, or waiter is not waiting
}
// Reaching here means v == 1: the count is zero and the waiter flag is set.
// NOTE(review): the `mutex_lock` / `notify_all` lines and the
// `absl::MutexLock` / `SignalAll` lines look like the removed/added pairs of
// the rendered diff -- confirm which pair is live.
mutex_lock l(mu_);
absl::MutexLock l(&mu_);
// A notification must not have been delivered already.
DCHECK(!notified_);
notified_ = true;
cond_var_.notify_all();
cond_var_.SignalAll();
}

// Blocks the calling thread until notified_ becomes true, which
// DecrementCount() sets when the count reaches zero. Returns immediately if
// the count is already zero at the moment the waiter flag is set.
inline void Wait() {
// Logs the address of this thread's kNonce on every call.
LOG(INFO) << "kNonce: " << (const void*)&kNonce;
// Track which thread waits on this counter: the first caller publishes its
// &kNonce in last_waiter_addr_; later calls must present the same address.
const void* prior_last_waiter_addr =
last_waiter_addr_.load(std::memory_order_relaxed);
if (prior_last_waiter_addr != nullptr) {
// NOTE(review): the failure text names WaitFor() although this is Wait().
CHECK_EQ(prior_last_waiter_addr, (const void*)&kNonce)
<< "multiple threads called WaitFor()";
} else {
// No waiter recorded yet: claim the slot. The exchange fails only if some
// other value was stored since the load above, which is treated as fatal.
auto expected = prior_last_waiter_addr;
if (!last_waiter_addr_.compare_exchange_strong(
expected, &kNonce, std::memory_order_relaxed)) {
LOG(FATAL) << "Tried to swap " << prior_last_waiter_addr << " with "
<< (const void*)&kNonce << " but found " << expected;
}
}
// Set the waiter flag (low bit); v is the state before the flag was set.
unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
// Count was already zero: nothing to wait for.
if ((v >> 1) == 0) return;
// NOTE(review): `mutex_lock l(mu_)` and `absl::MutexLock l(&mu_)` look like
// the removed/added pair of the rendered diff -- confirm which one is live.
mutex_lock l(mu_);
absl::MutexLock l(&mu_);

// only one thread may call Wait(). To support more than one thread,
// implement a counter num_to_exit, like in the Barrier class.
CHECK_EQ(num_waiting_, 0) << "multiple threads called Wait()";
// NOTE(review): num_waiting_ is incremented here and never decremented in the
// code shown, so a second blocking Wait() on the same object would trip the
// CHECK above -- confirm that is the intent.
num_waiting_++;

// Re-check notified_ after every wakeup so that a wakeup without the flag
// being set does not end the wait.
// NOTE(review): `wait(l)` / `Wait(&mu_)` are again a removed/added diff pair.
while (!notified_) {
cond_var_.wait(l);
cond_var_.Wait(&mu_);
}
}
// Wait for the specified time, return false iff the count has not dropped to
// zero before the timeout expired.
//
// `ms` is the timeout as a std::chrono duration in milliseconds. Returns true
// right away if the count is already zero when the waiter flag is set.
inline bool WaitFor(std::chrono::milliseconds ms) {
// Logs this object's address and this thread's kNonce address on every call.
LOG(INFO) << "this: " << this << " kNonce: " << (const void*)&kNonce;
// Track which thread waits on this counter: the first caller publishes its
// &kNonce in last_waiter_addr_; later calls must present the same address.
const void* prior_last_waiter_addr =
last_waiter_addr_.load(std::memory_order_relaxed);
if (prior_last_waiter_addr != nullptr) {
CHECK_EQ(prior_last_waiter_addr, (const void*)&kNonce)
<< "multiple threads called WaitFor(): " << last_waiter_addr_ << " "
<< &kNonce;
} else {
// No waiter recorded yet: claim the slot. The exchange fails only if some
// other value was stored since the load above, which is treated as fatal.
auto expected = prior_last_waiter_addr;
if (!last_waiter_addr_.compare_exchange_strong(
expected, &kNonce, std::memory_order_relaxed)) {
LOG(FATAL) << "Tried to swap " << prior_last_waiter_addr << " with "
<< (const void*)&kNonce << " but found " << expected;
}
// Logs the current stack trace when this thread is first recorded as waiter.
LOG(INFO) << tsl::CurrentStackTrace();
}

// Set the waiter flag (low bit); v is the state before the flag was set.
unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
// Count was already zero: report success without blocking.
if ((v >> 1) == 0) return true;
// NOTE(review): `mutex_lock l(mu_)` and `absl::MutexLock l(&mu_)` look like
// the removed/added pair of the rendered diff -- confirm which one is live.
mutex_lock l(mu_);
// Convert the chrono duration to the absl::Duration used by the absl wait.
absl::Duration timeout = absl::FromChrono(ms);
absl::MutexLock l(&mu_);

// only one thread may call Wait(). To support more than one thread,
// implement a counter num_to_exit, like in the Barrier class.
// NOTE(review): unlike Wait(), no num_waiting_ check follows this comment.

// NOTE(review): the std::cv_status lines and the WaitWithTimeout line look like
// the removed/added versions of the timed wait. WaitWithTimeout presumably
// returns true on timeout (see absl::CondVar docs) -- confirm. The same
// `timeout` is passed on every iteration, so a wakeup that leaves notified_
// false starts a fresh full-length wait -- confirm this is intended.
while (!notified_) {
const std::cv_status status = cond_var_.wait_for(l, ms);
if (status == std::cv_status::timeout) {
if (cond_var_.WaitWithTimeout(&mu_, timeout)) {
return false;
}
}
return true;
}

private:
// NOTE(review): this listing is a rendered diff, so several members appear
// twice (tsl `mutex`/`condition_variable` vs. absl `Mutex`/`CondVar`, and
// `notified_` with and without ABSL_GUARDED_BY) -- presumably the absl and
// annotated forms are the replacements; confirm against the real header.
// Lock taken by DecrementCount(), Wait() and WaitFor() around notified_.
mutex mu_;
// Waited on by Wait()/WaitFor() and signalled by DecrementCount().
condition_variable cond_var_;
absl::Mutex mu_;
absl::CondVar cond_var_;
// Packed state: the count is stored in the bits above bit 0.
std::atomic<int> state_; // low bit is waiter flag
bool notified_;
// Address of the kNonce of the thread first seen in Wait()/WaitFor();
// nullptr until a waiter has been recorded.
std::atomic<const void*> last_waiter_addr_ = nullptr;
// Incremented by Wait() after checking that it is zero.
int num_waiting_ ABSL_GUARDED_BY(mu_) = 0;
// Set to true by DecrementCount() when the count reaches zero with the
// waiter flag set; the wait loops spin on this flag.
bool notified_ ABSL_GUARDED_BY(mu_);
};
};

} // namespace tsl
Expand Down

0 comments on commit f182dba

Please sign in to comment.