Skip to content

Commit

Permalink
Simplify arbitration participant lock
Browse files Browse the repository at this point in the history
  • Loading branch information
tanjialiang committed Nov 15, 2024
1 parent 99979c4 commit 3322005
Show file tree
Hide file tree
Showing 13 changed files with 67 additions and 242 deletions.
51 changes: 14 additions & 37 deletions velox/common/memory/ArbitrationParticipant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ uint64_t ArbitrationParticipant::reclaim(
if (targetBytes == 0) {
return 0;
}
ArbitrationOperationTimedLock l(reclaimMutex_);
ArbitrationTimedLock l(reclaimMutex_, maxWaitTimeNs);
TestValue::adjust(
"facebook::velox::memory::ArbitrationParticipant::reclaim", this);
uint64_t reclaimedBytes{0};
Expand Down Expand Up @@ -320,7 +320,7 @@ uint64_t ArbitrationParticipant::shrinkLocked(bool reclaimAll) {

uint64_t ArbitrationParticipant::abort(
const std::exception_ptr& error) noexcept {
ArbitrationOperationTimedLock l(reclaimMutex_);
std::lock_guard<std::timed_mutex> l(reclaimMutex_);
return abortLocked(error);
}

Expand Down Expand Up @@ -403,52 +403,29 @@ std::string ArbitrationCandidate::toString() const {
}

#ifdef TSAN_BUILD
ArbitrationOperationTimedLock::ArbitrationOperationTimedLock(
std::timed_mutex& mutex)
ArbitrationTimedLock::ArbitrationTimedLock(
std::timed_mutex& mutex,
uint64_t /* unused */)
: mutex_(mutex) {
mutex_.lock();
}

ArbitrationOperationTimedLock::~ArbitrationOperationTimedLock() {
ArbitrationTimedLock::~ArbitrationTimedLock() {
mutex_.unlock();
}
#else
ArbitrationOperationTimedLock::ArbitrationOperationTimedLock(
std::timed_mutex& mutex) {
auto arbitrationContext = memoryArbitrationContext();
if (arbitrationContext == nullptr) {
std::unique_lock<std::timed_mutex> l(mutex);
timedLock_ = std::move(l);
return;
}
auto* operation = arbitrationContext->op;
if (operation == nullptr) {
VELOX_CHECK_EQ(
MemoryArbitrationContext::typeName(arbitrationContext->type),
MemoryArbitrationContext::typeName(
MemoryArbitrationContext::Type::kGlobal));
std::unique_lock<std::timed_mutex> l(mutex);
timedLock_ = std::move(l);
return;
}
VELOX_CHECK_EQ(
MemoryArbitrationContext::typeName(arbitrationContext->type),
MemoryArbitrationContext::typeName(
MemoryArbitrationContext::Type::kLocal));
std::unique_lock<std::timed_mutex> l(
mutex, std::chrono::nanoseconds(operation->timeoutNs()));
timedLock_ = std::move(l);
if (!timedLock_.owns_lock()) {
ArbitrationTimedLock::ArbitrationTimedLock(
std::timed_mutex& mutex,
uint64_t timeoutNs)
: mutex_(mutex) {
if (!mutex_.try_lock_for(std::chrono::nanoseconds(timeoutNs))) {
VELOX_MEM_ARBITRATION_TIMEOUT(fmt::format(
"Memory arbitration lock timed out on memory pool: {} after running {}",
operation->participant()->name(),
succinctNanos(operation->executionTimeNs())));
"Memory arbitration lock timed out when reclaiming from arbitration participant."));
}
}

ArbitrationOperationTimedLock::~ArbitrationOperationTimedLock() {
VELOX_CHECK(timedLock_.owns_lock());
timedLock_.unlock();
ArbitrationTimedLock::~ArbitrationTimedLock() {
mutex_.unlock();
}
#endif
} // namespace facebook::velox::memory
18 changes: 4 additions & 14 deletions velox/common/memory/ArbitrationParticipant.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,26 +49,16 @@ class ScopedArbitrationParticipant;
/// automatically be applied.
///
/// NOTE: TSAN is incompatible with std::timed_mutex when used with timeout. So
/// in TSAN build a trivial implementation is implemented.
#ifdef TSAN_BUILD
class ArbitrationOperationTimedLock {
/// in TSAN build a trivial lock is implemented.
class ArbitrationTimedLock {
public:
explicit ArbitrationOperationTimedLock(std::timed_mutex& mutex);
~ArbitrationOperationTimedLock();
ArbitrationTimedLock(std::timed_mutex& mutex, uint64_t timeoutNs);
~ArbitrationTimedLock();

private:
std::timed_mutex& mutex_;
};
#else
class ArbitrationOperationTimedLock {
public:
explicit ArbitrationOperationTimedLock(std::timed_mutex& mutex);
~ArbitrationOperationTimedLock();

private:
std::unique_lock<std::timed_mutex> timedLock_;
};
#endif
/// Manages the memory arbitration operations on a query memory pool. It also
/// tracks the arbitration stats during the query memory pool's lifecycle.
class ArbitrationParticipant
Expand Down
14 changes: 4 additions & 10 deletions velox/common/memory/MemoryArbitrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,12 +449,8 @@ bool MemoryArbitrator::Stats::operator<=(const Stats& other) const {
return !(*this > other);
}

MemoryArbitrationContext::MemoryArbitrationContext(
const MemoryPool* requestor,
ArbitrationOperation* _op)
: type(Type::kLocal), requestorName(requestor->name()), op(_op) {
VELOX_CHECK_NOT_NULL(op);
}
MemoryArbitrationContext::MemoryArbitrationContext(const MemoryPool* requestor)
: type(Type::kLocal), requestorName(requestor->name()) {}

std::string MemoryArbitrationContext::typeName(
MemoryArbitrationContext::Type type) {
Expand All @@ -469,10 +465,8 @@ std::string MemoryArbitrationContext::typeName(
}

ScopedMemoryArbitrationContext::ScopedMemoryArbitrationContext(
const MemoryPool* requestor,
ArbitrationOperation* op)
: savedArbitrationCtx_(arbitrationCtx),
currentArbitrationCtx_(requestor, op) {
const MemoryPool* requestor)
: savedArbitrationCtx_(arbitrationCtx), currentArbitrationCtx_(requestor) {
arbitrationCtx = &currentArbitrationCtx_;
}

Expand Down
12 changes: 3 additions & 9 deletions velox/common/memory/MemoryArbitrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,13 +421,9 @@ struct MemoryArbitrationContext {
/// global memory arbitration type.
const std::string requestorName;

ArbitrationOperation* const op;
explicit MemoryArbitrationContext(const MemoryPool* requestor);

MemoryArbitrationContext(
const MemoryPool* requestor,
ArbitrationOperation* _op);

MemoryArbitrationContext() : type(Type::kGlobal), op(nullptr) {}
MemoryArbitrationContext() : type(Type::kGlobal) {}
};

/// Object used to set/restore the memory arbitration context when a thread is
Expand All @@ -439,9 +435,7 @@ class ScopedMemoryArbitrationContext {
explicit ScopedMemoryArbitrationContext(
const MemoryArbitrationContext* context);

ScopedMemoryArbitrationContext(
const MemoryPool* requestor,
ArbitrationOperation* op);
explicit ScopedMemoryArbitrationContext(const MemoryPool* requestor);

~ScopedMemoryArbitrationContext();

Expand Down
2 changes: 1 addition & 1 deletion velox/common/memory/SharedArbitrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1341,7 +1341,7 @@ SharedArbitrator::ScopedArbitration::ScopedArbitration(
ArbitrationOperation* operation)
: arbitrator_(arbitrator),
operation_(operation),
arbitrationCtx_(operation->participant()->pool(), operation),
arbitrationCtx_(operation->participant()->pool()),
startTime_(std::chrono::steady_clock::now()) {
VELOX_CHECK_NOT_NULL(arbitrator_);
VELOX_CHECK_NOT_NULL(operation_);
Expand Down
66 changes: 16 additions & 50 deletions velox/common/memory/tests/ArbitrationParticipantTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1419,7 +1419,7 @@ DEBUG_ONLY_TEST_F(ArbitrationParticipantTest, reclaimLock) {
folly::EventCount reclaim1CompletedWait;
std::thread reclaimThread1([&]() {
memory::MemoryReclaimer::Stats stats;
ASSERT_EQ(scopedParticipant->reclaim(MB, 1'000'000, stats), 0);
ASSERT_EQ(scopedParticipant->reclaim(MB, 1'000'000'000'000, stats), 0);
ASSERT_EQ(stats.numNonReclaimableAttempts, 0);
reclaim1CompletedFlag = true;
reclaim1CompletedWait.notifyAll();
Expand Down Expand Up @@ -1454,7 +1454,7 @@ DEBUG_ONLY_TEST_F(ArbitrationParticipantTest, reclaimLock) {
folly::EventCount reclaim2CompletedWait;
std::thread reclaimThread2([&]() {
memory::MemoryReclaimer::Stats stats;
ASSERT_EQ(scopedParticipant->reclaim(MB, 1'000'000, stats), 0);
ASSERT_EQ(scopedParticipant->reclaim(MB, 1'000'000'000'000, stats), 0);
ASSERT_EQ(stats.numNonReclaimableAttempts, 0);
reclaim2CompletedFlag = true;
reclaim2CompletedWait.notifyAll();
Expand Down Expand Up @@ -1896,65 +1896,31 @@ TEST_F(ArbitrationParticipantTest, arbitrationOperationTimedLock) {
};

struct TestData {
std::string type;
uint64_t lockHoldTimeNs;
uint64_t opTimeoutNs;
};

std::timed_mutex mutex;
std::vector<TestData> testDataVec{
{"local", 1'000'000'000UL, 2'000'000'000UL},
{"local", 2'000'000'000UL, 1'000'000'000UL},
{"global", 1'000'000'000UL, 2'000'000'000UL},
{"global", 2'000'000'000UL, 1'000'000'000UL},
{"none", 1'000'000'000UL, 2'000'000'000UL}};
{1'000'000'000UL, 2'000'000'000UL}, {2'000'000'000UL, 1'000'000'000UL}};

for (auto& testData : testDataVec) {
ScopedArbitrationParticipant scopedArbitrationParticipant(
participant, participantPool);
ArbitrationOperation operation(
std::move(scopedArbitrationParticipant), 1024, testData.opTimeoutNs);
if (testData.type == "local") {
MemoryArbitrationContext ctx(participantPool.get(), &operation);
ScopedMemoryArbitrationContext scopedCtx(&ctx);

folly::EventCount lockWait;
std::atomic_bool lockWaitFlag{true};
auto lockHolder = createLockHolderThread(
mutex, testData.lockHoldTimeNs, lockWait, lockWaitFlag);
std::unique_ptr<ArbitrationOperationTimedLock> timedLock{nullptr};
lockWait.await([&]() { return !lockWaitFlag.load(); });
if (testData.lockHoldTimeNs < testData.opTimeoutNs) {
timedLock = std::make_unique<ArbitrationOperationTimedLock>(mutex);
ASSERT_FALSE(mutex.try_lock());
} else {
VELOX_ASSERT_THROW(
std::make_unique<ArbitrationOperationTimedLock>(mutex),
"Memory arbitration lock timed out");
}
lockHolder.join();
} else if (testData.type == "global") {
MemoryArbitrationContext ctx;
ScopedMemoryArbitrationContext scopedCtx(&ctx);

folly::EventCount lockWait;
std::atomic_bool lockWaitFlag{true};
auto lockHolder = createLockHolderThread(
mutex, testData.lockHoldTimeNs, lockWait, lockWaitFlag);
lockWait.await([&]() { return !lockWaitFlag.load(); });
ArbitrationOperationTimedLock timedLock(mutex);
folly::EventCount lockWait;
std::atomic_bool lockWaitFlag{true};
auto lockHolder = createLockHolderThread(
mutex, testData.lockHoldTimeNs, lockWait, lockWaitFlag);
std::unique_ptr<ArbitrationTimedLock> timedLock{nullptr};
lockWait.await([&]() { return !lockWaitFlag.load(); });
if (testData.lockHoldTimeNs < testData.opTimeoutNs) {
timedLock =
std::make_unique<ArbitrationTimedLock>(mutex, testData.opTimeoutNs);
ASSERT_FALSE(mutex.try_lock());
lockHolder.join();
} else {
folly::EventCount lockWait;
std::atomic_bool lockWaitFlag{true};
auto lockHolder = createLockHolderThread(
mutex, testData.lockHoldTimeNs, lockWait, lockWaitFlag);
lockWait.await([&]() { return !lockWaitFlag.load(); });
ArbitrationOperationTimedLock timedLock(mutex);
ASSERT_FALSE(mutex.try_lock());
lockHolder.join();
VELOX_ASSERT_THROW(
std::make_unique<ArbitrationTimedLock>(mutex, testData.opTimeoutNs),
"Memory arbitration lock timed out");
}
lockHolder.join();
}
}
#endif
Expand Down
22 changes: 4 additions & 18 deletions velox/common/memory/tests/MemoryArbitratorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -990,41 +990,27 @@ TEST_F(MemoryReclaimerTest, arbitrationContext) {
ASSERT_FALSE(isSpillMemoryPool(leafChild2.get()));
ASSERT_TRUE(memoryArbitrationContext() == nullptr);
{
auto arbitrationStructs =
test::ArbitrationTestStructs::createArbitrationTestStructs(leafChild1);
ScopedMemoryArbitrationContext arbitrationContext(
leafChild1.get(), arbitrationStructs.operation.get());
ScopedMemoryArbitrationContext arbitrationContext(leafChild1.get());
ASSERT_TRUE(memoryArbitrationContext() != nullptr);
ASSERT_EQ(memoryArbitrationContext()->requestorName, leafChild1->name());
}
ASSERT_TRUE(memoryArbitrationContext() == nullptr);
{
auto arbitrationStructs =
test::ArbitrationTestStructs::createArbitrationTestStructs(leafChild2);
ScopedMemoryArbitrationContext arbitrationContext(
leafChild2.get(), arbitrationStructs.operation.get());
ScopedMemoryArbitrationContext arbitrationContext(leafChild2.get());
ASSERT_TRUE(memoryArbitrationContext() != nullptr);
ASSERT_EQ(memoryArbitrationContext()->requestorName, leafChild2->name());
}
ASSERT_TRUE(memoryArbitrationContext() == nullptr);
std::thread nonAbitrationThread([&]() {
ASSERT_TRUE(memoryArbitrationContext() == nullptr);
{
auto arbitrationStructs =
test::ArbitrationTestStructs::createArbitrationTestStructs(
leafChild1);
ScopedMemoryArbitrationContext arbitrationContext(
leafChild1.get(), arbitrationStructs.operation.get());
ScopedMemoryArbitrationContext arbitrationContext(leafChild1.get());
ASSERT_TRUE(memoryArbitrationContext() != nullptr);
ASSERT_EQ(memoryArbitrationContext()->requestorName, leafChild1->name());
}
ASSERT_TRUE(memoryArbitrationContext() == nullptr);
{
auto arbitrationStructs =
test::ArbitrationTestStructs::createArbitrationTestStructs(
leafChild2);
ScopedMemoryArbitrationContext arbitrationContext(
leafChild2.get(), arbitrationStructs.operation.get());
ScopedMemoryArbitrationContext arbitrationContext(leafChild2.get());
ASSERT_TRUE(memoryArbitrationContext() != nullptr);
ASSERT_EQ(memoryArbitrationContext()->requestorName, leafChild2->name());
}
Expand Down
5 changes: 1 addition & 4 deletions velox/common/memory/tests/MemoryPoolTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3888,10 +3888,7 @@ TEST_P(MemoryPoolTest, overuseUnderArbitration) {
ASSERT_FALSE(child->maybeReserve(2 * kMaxSize));
ASSERT_EQ(child->usedBytes(), 0);
ASSERT_EQ(child->reservedBytes(), 0);
auto arbitrationTestStructs =
test::ArbitrationTestStructs::createArbitrationTestStructs(root);
ScopedMemoryArbitrationContext scopedMemoryArbitration(
root.get(), arbitrationTestStructs.operation.get());
ScopedMemoryArbitrationContext scopedMemoryArbitration(root.get());
ASSERT_TRUE(underMemoryArbitration());
ASSERT_TRUE(child->maybeReserve(2 * kMaxSize));
ASSERT_EQ(child->usedBytes(), 0);
Expand Down
24 changes: 4 additions & 20 deletions velox/dwio/dwrf/test/E2EWriterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1735,11 +1735,7 @@ DEBUG_ONLY_TEST_F(E2EWriterTest, memoryReclaimOnWrite) {
const auto oldReservedBytes = writerPool->reservedBytes();
const auto oldUsedBytes = writerPool->usedBytes();
{
auto arbitrationStructs =
memory::test::ArbitrationTestStructs::createArbitrationTestStructs(
writerPool);
memory::ScopedMemoryArbitrationContext arbitrationCtx(
writerPool.get(), arbitrationStructs.operation.get());
memory::ScopedMemoryArbitrationContext arbitrationCtx(writerPool.get());
writerPool->reclaim(1L << 30, 0, stats);
}
ASSERT_EQ(stats.numNonReclaimableAttempts, 0);
Expand Down Expand Up @@ -1778,11 +1774,7 @@ DEBUG_ONLY_TEST_F(E2EWriterTest, memoryReclaimOnWrite) {
writer->testingNonReclaimableSection() = false;
stats.numNonReclaimableAttempts = 0;
{
auto arbitrationStructs =
memory::test::ArbitrationTestStructs::createArbitrationTestStructs(
writerPool);
memory::ScopedMemoryArbitrationContext arbitrationCtx(
writerPool.get(), arbitrationStructs.operation.get());
memory::ScopedMemoryArbitrationContext arbitrationCtx(writerPool.get());
const auto reclaimedBytes = writerPool->reclaim(1L << 30, 0, stats);
ASSERT_GT(reclaimedBytes, 0);
}
Expand Down Expand Up @@ -2124,11 +2116,7 @@ DEBUG_ONLY_TEST_F(E2EWriterTest, memoryReclaimThreshold) {
*writerPool, reclaimableBytes));
ASSERT_GT(reclaimableBytes, 0);
{
auto arbitrationStructs =
memory::test::ArbitrationTestStructs::createArbitrationTestStructs(
writerPool);
memory::ScopedMemoryArbitrationContext arbitrationCtx(
writerPool.get(), arbitrationStructs.operation.get());
memory::ScopedMemoryArbitrationContext arbitrationCtx(writerPool.get());
ASSERT_GT(writerPool->reclaim(1L << 30, 0, stats), 0);
}
ASSERT_GT(stats.reclaimExecTimeUs, 0);
Expand All @@ -2138,11 +2126,7 @@ DEBUG_ONLY_TEST_F(E2EWriterTest, memoryReclaimThreshold) {
*writerPool, reclaimableBytes));
ASSERT_EQ(reclaimableBytes, 0);
{
auto arbitrationStructs =
memory::test::ArbitrationTestStructs::createArbitrationTestStructs(
writerPool);
memory::ScopedMemoryArbitrationContext arbitrationCtx(
writerPool.get(), arbitrationStructs.operation.get());
memory::ScopedMemoryArbitrationContext arbitrationCtx(writerPool.get());
ASSERT_EQ(writerPool->reclaim(1L << 30, 0, stats), 0);
}
ASSERT_EQ(stats.numNonReclaimableAttempts, 0);
Expand Down
Loading

0 comments on commit 3322005

Please sign in to comment.