// Copyright 2017 The Abseil Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "absl/synchronization/notification.h" #include // NOLINT(build/c++11) #include #include #include "gtest/gtest.h" #include "absl/base/attributes.h" #include "absl/base/config.h" #include "absl/base/internal/tracing.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" namespace absl { ABSL_NAMESPACE_BEGIN // A thread-safe class that holds a counter. class ThreadSafeCounter { public: ThreadSafeCounter() : count_(0) {} void Increment() { MutexLock lock(&mutex_); ++count_; } int Get() const { MutexLock lock(&mutex_); return count_; } void WaitUntilGreaterOrEqual(int n) { MutexLock lock(&mutex_); auto cond = [this, n]() { return count_ >= n; }; mutex_.Await(Condition(&cond)); } private: mutable Mutex mutex_; int count_; }; // Runs the |i|'th worker thread for the tests in BasicTests(). Increments the // |ready_counter|, waits on the |notification|, and then increments the // |done_counter|. static void RunWorker(int i, ThreadSafeCounter* ready_counter, Notification* notification, ThreadSafeCounter* done_counter) { ready_counter->Increment(); notification->WaitForNotification(); done_counter->Increment(); } // Tests that the |notification| properly blocks and awakens threads. Assumes // that the |notification| is not yet triggered. If |notify_before_waiting| is // true, the |notification| is triggered before any threads are created, so the // threads never block in WaitForNotification(). Otherwise, the |notification| // is triggered at a later point when most threads are likely to be blocking in // WaitForNotification(). static void BasicTests(bool notify_before_waiting, Notification* notification) { EXPECT_FALSE(notification->HasBeenNotified()); EXPECT_FALSE( notification->WaitForNotificationWithTimeout(absl::Milliseconds(0))); EXPECT_FALSE(notification->WaitForNotificationWithDeadline(absl::Now())); const absl::Duration delay = absl::Milliseconds(50); const absl::Time start = absl::Now(); EXPECT_FALSE(notification->WaitForNotificationWithTimeout(delay)); const absl::Duration elapsed = absl::Now() - start; // Allow for a slight early return, to account for quality of implementation // issues on various platforms. const absl::Duration slop = absl::Milliseconds(5); EXPECT_LE(delay - slop, elapsed) << "WaitForNotificationWithTimeout returned " << delay - elapsed << " early (with " << slop << " slop), start time was " << start; ThreadSafeCounter ready_counter; ThreadSafeCounter done_counter; if (notify_before_waiting) { notification->Notify(); } // Create a bunch of threads that increment the |done_counter| after being // notified. const int kNumThreads = 10; std::vector workers; for (int i = 0; i < kNumThreads; ++i) { workers.push_back(std::thread(&RunWorker, i, &ready_counter, notification, &done_counter)); } if (!notify_before_waiting) { ready_counter.WaitUntilGreaterOrEqual(kNumThreads); // Workers have not been notified yet, so the |done_counter| should be // unmodified. EXPECT_EQ(0, done_counter.Get()); notification->Notify(); } // After notifying and then joining the workers, both counters should be // fully incremented. notification->WaitForNotification(); // should exit immediately EXPECT_TRUE(notification->HasBeenNotified()); EXPECT_TRUE(notification->WaitForNotificationWithTimeout(absl::Seconds(0))); EXPECT_TRUE(notification->WaitForNotificationWithDeadline(absl::Now())); for (std::thread& worker : workers) { worker.join(); } EXPECT_EQ(kNumThreads, ready_counter.Get()); EXPECT_EQ(kNumThreads, done_counter.Get()); } TEST(NotificationTest, SanityTest) { Notification local_notification1, local_notification2; BasicTests(false, &local_notification1); BasicTests(true, &local_notification2); } #if ABSL_HAVE_ATTRIBUTE_WEAK namespace base_internal { namespace { using TraceRecord = std::tuple; thread_local TraceRecord tls_signal; thread_local TraceRecord tls_wait; thread_local TraceRecord tls_continue; thread_local TraceRecord tls_observed; } // namespace // Strong extern "C" implementation. extern "C" { void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceWait)(const void* object, ObjectKind kind) { tls_wait = {object, kind}; } void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceContinue)(const void* object, ObjectKind kind) { tls_continue = {object, kind}; } void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceSignal)(const void* object, ObjectKind kind) { tls_signal = {object, kind}; } void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceObserved)(const void* object, ObjectKind kind) { tls_observed = {object, kind}; } } // extern "C" TEST(NotificationTest, TracesNotify) { Notification n; tls_signal = {}; n.Notify(); EXPECT_EQ(tls_signal, TraceRecord(&n, ObjectKind::kNotification)); } TEST(NotificationTest, TracesWaitForNotification) { Notification n; n.Notify(); tls_wait = tls_continue = {}; n.WaitForNotification(); EXPECT_EQ(tls_wait, TraceRecord(&n, ObjectKind::kNotification)); EXPECT_EQ(tls_continue, TraceRecord(&n, ObjectKind::kNotification)); } TEST(NotificationTest, TracesWaitForNotificationWithTimeout) { Notification n; tls_wait = tls_continue = {}; n.WaitForNotificationWithTimeout(absl::Milliseconds(1)); EXPECT_EQ(tls_wait, TraceRecord(&n, ObjectKind::kNotification)); EXPECT_EQ(tls_continue, TraceRecord(nullptr, ObjectKind::kNotification)); n.Notify(); tls_wait = tls_continue = {}; n.WaitForNotificationWithTimeout(absl::Milliseconds(1)); EXPECT_EQ(tls_wait, TraceRecord(&n, ObjectKind::kNotification)); EXPECT_EQ(tls_continue, TraceRecord(&n, ObjectKind::kNotification)); } TEST(NotificationTest, TracesHasBeenNotified) { Notification n; tls_observed = {}; ASSERT_FALSE(n.HasBeenNotified()); EXPECT_EQ(tls_observed, TraceRecord(nullptr, ObjectKind::kUnknown)); n.Notify(); tls_observed = {}; ASSERT_TRUE(n.HasBeenNotified()); EXPECT_EQ(tls_observed, TraceRecord(&n, ObjectKind::kNotification)); } } // namespace base_internal #endif // ABSL_HAVE_ATTRIBUTE_WEAK ABSL_NAMESPACE_END } // namespace absl