Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more safety checks to BlockingCounter #3019

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion tsl/platform/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ cc_library(
compatible_with = get_compatible_with_portable(),
deps = [
":logging",
":mutex",
":stacktrace",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
],
)

Expand Down
77 changes: 64 additions & 13 deletions tsl/platform/blocking_counter.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@ limitations under the License.
#define TENSORFLOW_TSL_PLATFORM_BLOCKING_COUNTER_H_

#include <atomic>
#include <chrono> // NOLINT
#include <cstdint>

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/mutex.h"
#include "tsl/platform/stacktrace.h"

namespace tsl {

Expand All @@ -28,51 +33,97 @@ class BlockingCounter {
// Creates a counter that has to be decremented `initial_count` times before
// a waiter is released. `initial_count` must be non-negative; this is
// enforced with CHECK_GE.
//
// state_ keeps the count in the bits above bit 0 (hence the shift by one);
// bit 0 is the "waiter has arrived" flag set by Wait()/WaitFor().
BlockingCounter(int initial_count)
    : state_(initial_count << 1), notified_(false) {
  CHECK_GE(initial_count, 0);
  // Debug-only round-trip check of the shifted value. The shift is done on
  // the unsigned representation so the check itself cannot overflow a
  // signed int.
  DCHECK_EQ((static_cast<unsigned int>(initial_count) << 1) >> 1,
            initial_count);
}

// No explicit cleanup is performed; the compiler-generated destructor is
// used.
~BlockingCounter() = default;

// Per-thread identity token. Being thread_local, each concurrently running
// thread has its own instance of this variable and therefore its own
// `&kNonce`. Wait() and WaitFor() record that address in last_waiter_addr_
// and compare against it to detect a second thread trying to wait.
static thread_local constexpr char kNonce = 0;

// Decrements the count by one. If this call brings the count to zero and a
// waiter has already set the waiter flag (bit 0 of state_), marks the
// counter as notified and wakes the waiter.
inline void DecrementCount() {
  // The count occupies the bits above bit 0, so subtracting 2 lowers it by
  // one. `v` is the state as it is after this decrement.
  unsigned int v = state_.fetch_sub(2, std::memory_order_acq_rel) - 2;
  if (v != 1) {
    // (v + 2) is the state before this call; with the waiter bit masked off
    // it must be non-zero, otherwise the count was decremented below zero.
    DCHECK_NE(((v + 2) & ~1), 0);
    return;  // either count has not dropped to 0, or waiter is not waiting
  }
  // v == 1: the count is zero and the waiter flag is set.
  absl::MutexLock l(&mu_);
  DCHECK(!notified_);
  notified_ = true;
  cond_var_.SignalAll();
}

// Blocks the calling thread until the count has dropped to zero.
//
// Only a single thread may ever wait on a given counter. The first waiter
// records the address of its thread-local kNonce in last_waiter_addr_; any
// later call whose kNonce address differs aborts the process.
inline void Wait() {
  const void* const this_thread = static_cast<const void*>(&kNonce);
  // NOTE(review): this logs on every call and looks like debugging output;
  // confirm it is meant to stay.
  LOG(INFO) << "kNonce: " << this_thread;

  const void* prior_last_waiter_addr =
      last_waiter_addr_.load(std::memory_order_relaxed);
  if (prior_last_waiter_addr != nullptr) {
    // A waiter is already registered: it has to be this very thread.
    CHECK_EQ(prior_last_waiter_addr, this_thread)
        << "multiple threads called Wait()";
  } else {
    // No waiter yet: claim the slot. The exchange fails only if the slot
    // was filled between the load above and this point.
    auto expected = prior_last_waiter_addr;
    if (!last_waiter_addr_.compare_exchange_strong(
            expected, this_thread, std::memory_order_relaxed)) {
      LOG(FATAL) << "Tried to swap " << prior_last_waiter_addr << " with "
                 << this_thread << " but found " << expected;
    }
  }

  // Set the waiter flag (bit 0). `v` is the state before the flag was set;
  // if its count part is already zero there is nothing to wait for.
  unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
  if ((v >> 1) == 0) return;
  absl::MutexLock l(&mu_);

  // only one thread may call Wait(). To support more than one thread,
  // implement a counter num_to_exit, like in the Barrier class.
  CHECK_EQ(num_waiting_, 0) << "multiple threads called Wait()";
  num_waiting_++;

  // Loop on the flag rather than on a single wake-up of the condition
  // variable.
  while (!notified_) {
    cond_var_.Wait(&mu_);
  }
}
// Wait for the specified time, return false iff the count has not dropped to
// zero before the timeout expired.
//
// As with Wait(), only a single thread may ever wait on a given counter: the
// first waiter records the address of its thread-local kNonce in
// last_waiter_addr_ and any call whose kNonce address differs aborts the
// process.
inline bool WaitFor(std::chrono::milliseconds ms) {
  // Cast explicitly: streaming `&kNonce` (a const char*) would print it as a
  // C string instead of as an address.
  const void* const this_thread = static_cast<const void*>(&kNonce);
  // NOTE(review): this logs on every call and looks like debugging output;
  // confirm it is meant to stay.
  LOG(INFO) << "this: " << this << " kNonce: " << this_thread;

  const void* prior_last_waiter_addr =
      last_waiter_addr_.load(std::memory_order_relaxed);
  if (prior_last_waiter_addr != nullptr) {
    // A waiter is already registered: it has to be this very thread. The
    // message prints the value that was actually compared.
    CHECK_EQ(prior_last_waiter_addr, this_thread)
        << "multiple threads called WaitFor(): " << prior_last_waiter_addr
        << " " << this_thread;
  } else {
    // No waiter yet: claim the slot. The exchange fails only if the slot
    // was filled between the load above and this point.
    auto expected = prior_last_waiter_addr;
    if (!last_waiter_addr_.compare_exchange_strong(
            expected, this_thread, std::memory_order_relaxed)) {
      LOG(FATAL) << "Tried to swap " << prior_last_waiter_addr << " with "
                 << this_thread << " but found " << expected;
    }
    // Logged once, when this thread first registers as the waiter.
    LOG(INFO) << tsl::CurrentStackTrace();
  }

  // Set the waiter flag (bit 0). `v` is the state before the flag was set;
  // if its count part is already zero there is nothing to wait for.
  unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
  if ((v >> 1) == 0) return true;
  absl::Duration timeout = absl::FromChrono(ms);
  absl::MutexLock l(&mu_);

  // The single-waiter rule is enforced by the last_waiter_addr_ check above;
  // num_waiting_ is not touched here.

  // NOTE(review): every iteration waits for the full `timeout` again, so a
  // wake-up that leaves notified_ false restarts the timer.
  while (!notified_) {
    // WaitWithTimeout() returns true when the timeout expired.
    if (cond_var_.WaitWithTimeout(&mu_, timeout)) {
      return false;
    }
  }
  return true;
}

private:
// Protects notified_ and num_waiting_; cond_var_ is waited on with it held.
absl::Mutex mu_;
// Signalled by DecrementCount() once the count has reached zero while a
// waiter is registered.
absl::CondVar cond_var_;
std::atomic<int> state_;  // low bit is waiter flag
// Address of the waiting thread's thread-local kNonce, or nullptr while no
// thread has waited yet. Brace-initialised because std::atomic is not
// copyable.
std::atomic<const void*> last_waiter_addr_{nullptr};
// Number of threads that have entered the blocking part of Wait().
int num_waiting_ ABSL_GUARDED_BY(mu_) = 0;
// Set to true by DecrementCount() when it wakes the waiter.
bool notified_ ABSL_GUARDED_BY(mu_);
};

} // namespace tsl
Expand Down