diff --git a/paddle/phi/core/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc b/paddle/phi/core/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc
index a0448c77d5d0f0..e6104cdcc03467 100644
--- a/paddle/phi/core/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc
+++ b/paddle/phi/core/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc
@@ -141,7 +141,14 @@ phi::Allocation *VirtualMemoryAutoGrowthBestFitAllocator::AllocateImpl(
 void VirtualMemoryAutoGrowthBestFitAllocator::FreeImpl(
     phi::Allocation *allocation) {
   std::lock_guard<SpinLock> guard(spinlock_);
-  auto block_it = static_cast<BlockAllocation *>(allocation)->block_it_;
+  void *ptr = allocation->ptr();
+  auto block_it = FindBlockByPtr(ptr);
+  if (block_it == all_blocks_.end()) {
+    VLOG(4) << "[VMM][FreeImplMissingBlock] ptr=" << ptr
+            << " allocation_size=" << allocation->size();
+    delete allocation;
+    return;
+  }
   TryMergeBlock2Blocks(block_it);
   delete allocation;
 }
@@ -160,6 +167,14 @@ bool VirtualMemoryAutoGrowthBestFitAllocator::CollectTensorParts(
   return false;
 }
 
+std::list<Block>::iterator
+VirtualMemoryAutoGrowthBestFitAllocator::FindBlockByPtr(void *ptr) {
+  for (auto it = all_blocks_.begin(); it != all_blocks_.end(); ++it) {
+    if (it->ptr_ == ptr) return it;
+  }
+  return all_blocks_.end();
+}
+
 void VirtualMemoryAutoGrowthBestFitAllocator::TryMergeBlock2Blocks(
     std::list<Block>::iterator block) {
   if (block->ptr_ == all_blocks_.front().ptr_ &&
diff --git a/paddle/phi/core/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h b/paddle/phi/core/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h
index 6b80b17bcedaa2..93490cbe46eea6 100644
--- a/paddle/phi/core/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h
+++ b/paddle/phi/core/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h
@@ -83,6 +83,7 @@ class VirtualMemoryAutoGrowthBestFitAllocator : public Allocator {
   void ExtendOrCompact(size_t size);
   void TryMergeBlock2Blocks(std::list<Block>::iterator iter);
   void DumpInfo(std::string phase) const;
+  std::list<Block>::iterator FindBlockByPtr(void *ptr);
 
   std::shared_ptr<Allocator> underlying_allocator_;
   std::unique_ptr<MemoryCompactor> memory_compactor_;
diff --git a/test/cpp/fluid/memory/vmm_auto_growth_best_fit_allocator_test.cu b/test/cpp/fluid/memory/vmm_auto_growth_best_fit_allocator_test.cu
index f3e74a6e179339..70dfa14b5eb637 100644
--- a/test/cpp/fluid/memory/vmm_auto_growth_best_fit_allocator_test.cu
+++ b/test/cpp/fluid/memory/vmm_auto_growth_best_fit_allocator_test.cu
@@ -13,9 +13,13 @@
 // limitations under the License.
 
 #include "paddle/phi/core/memory/allocation/cuda_virtual_mem_allocator.h"
+// Expose internals for white-box testing.
+#define private public
 #include "paddle/phi/core/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h"
+#undef private
 
 #include "gtest/gtest.h"
+#include "paddle/common/errors.h"
 #include "paddle/phi/core/memory/memory.h"
 
 namespace paddle {
@@ -38,6 +42,59 @@ TEST(test_vmm_allocator, test_mem_stats) {
   EXPECT_EQ(DeviceMemoryStatCurrentValue("Reserved", 0), 0);
 }
 
+class DummyAllocator : public Allocator {
+ public:
+  bool IsAllocThreadSafe() const override { return true; }
+
+ protected:
+  phi::Allocation* AllocateImpl(size_t) override {
+    PADDLE_THROW(common::errors::Unavailable(
+        "DummyAllocator::AllocateImpl should not be called."));
+  }
+  void FreeImpl(phi::Allocation*) override {}
+};
+
+// Expose FreeImpl for testing.
+class ExposedVmmAllocator : public VirtualMemoryAutoGrowthBestFitAllocator {
+ public:
+  using VirtualMemoryAutoGrowthBestFitAllocator::FreeImpl;
+  using VirtualMemoryAutoGrowthBestFitAllocator::
+      VirtualMemoryAutoGrowthBestFitAllocator;
+};
+
+TEST(test_vmm_allocator, free_impl_handles_stale_iterator) {
+  auto underlying = std::make_shared<DummyAllocator>();
+  phi::GPUPlace place(0);
+  ExposedVmmAllocator allocator(underlying, 256, place);
+
+  // Manually construct blocks: [free-prev][used-target][free-next]
+  allocator.all_blocks_.clear();
+  auto prev = allocator.all_blocks_.emplace(
+      allocator.all_blocks_.end(), reinterpret_cast<void *>(0x1000), 1024, true);
+  auto target = allocator.all_blocks_.emplace(allocator.all_blocks_.end(),
+                                              reinterpret_cast<void *>(0x1400),
+                                              2048,
+                                              false);
+  auto next = allocator.all_blocks_.emplace(
+      allocator.all_blocks_.end(), reinterpret_cast<void *>(0x1C00), 4096, true);
+
+  allocator.free_blocks_.clear();
+  allocator.free_blocks_.emplace(std::make_pair(prev->size_, prev->ptr_), prev);
+  allocator.free_blocks_.emplace(std::make_pair(next->size_, next->ptr_), next);
+
+  // Stale allocation keeps an iterator to the "target" block, which will be
+  // erased before calling FreeImpl to simulate a dangling iterator.
+  auto stale_allocation = new BlockAllocation(target, place);
+
+  // Invalidate the iterator by erasing the target block (simulating previous
+  // merge/erase). free_blocks_ deliberately not updated to mimic the
+  // inconsistent state seen in the crash reports.
+  allocator.all_blocks_.erase(target);
+
+  // FreeImpl should not crash; it should detect the missing block and return.
+  EXPECT_NO_THROW(allocator.FreeImpl(stale_allocation));
+}
+
 }  // namespace allocation
 }  // namespace memory
 }  // namespace paddle