Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,14 @@ phi::Allocation *VirtualMemoryAutoGrowthBestFitAllocator::AllocateImpl(
void VirtualMemoryAutoGrowthBestFitAllocator::FreeImpl(
phi::Allocation *allocation) {
std::lock_guard<SpinLock> guard(spinlock_);
auto block_it = static_cast<BlockAllocation *>(allocation)->block_it_;
void *ptr = allocation->ptr();
auto block_it = FindBlockByPtr(ptr);
if (block_it == all_blocks_.end()) {
VLOG(4) << "[VMM][FreeImplMissingBlock] ptr=" << ptr
<< " allocation_size=" << allocation->size();
delete allocation;
return;
}
TryMergeBlock2Blocks(block_it);
delete allocation;
}
Expand All @@ -160,6 +167,14 @@ bool VirtualMemoryAutoGrowthBestFitAllocator::CollectTensorParts(
return false;
}

std::list<Block>::iterator
VirtualMemoryAutoGrowthBestFitAllocator::FindBlockByPtr(void *ptr) {
for (auto it = all_blocks_.begin(); it != all_blocks_.end(); ++it) {
if (it->ptr_ == ptr) return it;
}
return all_blocks_.end();
}

void VirtualMemoryAutoGrowthBestFitAllocator::TryMergeBlock2Blocks(
std::list<Block>::iterator block) {
if (block->ptr_ == all_blocks_.front().ptr_ &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class VirtualMemoryAutoGrowthBestFitAllocator : public Allocator {
void ExtendOrCompact(size_t size);
void TryMergeBlock2Blocks(std::list<Block>::iterator iter);
void DumpInfo(std::string phase) const;
std::list<Block>::iterator FindBlockByPtr(void *ptr);

std::shared_ptr<Allocator> underlying_allocator_;
std::unique_ptr<MemoryCompactionStrategy> memory_compactor_;
Expand Down
57 changes: 57 additions & 0 deletions test/cpp/fluid/memory/vmm_auto_growth_best_fit_allocator_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
// limitations under the License.

#include "paddle/phi/core/memory/allocation/cuda_virtual_mem_allocator.h"
// Expose internals for white-box testing.
#define private public
#include "paddle/phi/core/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h"
#undef private

#include "gtest/gtest.h"
#include "paddle/common/errors.h"
#include "paddle/phi/core/memory/memory.h"

namespace paddle {
Expand All @@ -38,6 +42,59 @@ TEST(test_vmm_allocator, test_mem_stats) {
EXPECT_EQ(DeviceMemoryStatCurrentValue("Reserved", 0), 0);
}

class DummyAllocator : public Allocator {
public:
bool IsAllocThreadSafe() const override { return true; }

protected:
phi::Allocation* AllocateImpl(size_t) override {
PADDLE_THROW(common::errors::Unavailable(
"DummyAllocator::AllocateImpl should not be called."));
}
void FreeImpl(phi::Allocation*) override {}
};

// Expose FreeImpl for testing.
class ExposedVmmAllocator : public VirtualMemoryAutoGrowthBestFitAllocator {
public:
using VirtualMemoryAutoGrowthBestFitAllocator::FreeImpl;
using VirtualMemoryAutoGrowthBestFitAllocator::
VirtualMemoryAutoGrowthBestFitAllocator;
};

TEST(test_vmm_allocator, free_impl_handles_stale_iterator) {
auto underlying = std::make_shared<DummyAllocator>();
phi::GPUPlace place(0);
ExposedVmmAllocator allocator(underlying, 256, place);

// Manually construct blocks: [free-prev][used-target][free-next]
allocator.all_blocks_.clear();
auto prev = allocator.all_blocks_.emplace(
allocator.all_blocks_.end(), reinterpret_cast<void*>(0x1000), 1024, true);
auto target = allocator.all_blocks_.emplace(allocator.all_blocks_.end(),
reinterpret_cast<void*>(0x1400),
2048,
false);
auto next = allocator.all_blocks_.emplace(
allocator.all_blocks_.end(), reinterpret_cast<void*>(0x1C00), 4096, true);

allocator.free_blocks_.clear();
allocator.free_blocks_.emplace(std::make_pair(prev->size_, prev->ptr_), prev);
allocator.free_blocks_.emplace(std::make_pair(next->size_, next->ptr_), next);

// Stale allocation keeps an iterator to the "target" block, which will be
// erased before calling FreeImpl to simulate a dangling iterator.
auto stale_allocation = new BlockAllocation(target, place);

// Invalidate the iterator by erasing the target block (simulating previous
// merge/erase). free_blocks_ deliberately not updated to mimic the
// inconsistent state seen in the crash reports.
allocator.all_blocks_.erase(target);

// FreeImpl should not crash; it should detect the missing block and return.
EXPECT_NO_THROW(allocator.FreeImpl(stale_allocation));
}

} // namespace allocation
} // namespace memory
} // namespace paddle
Loading