diff --git a/libresolve/src/remediate.rs b/libresolve/src/remediate.rs index b9d384fd..30c6b9c3 100644 --- a/libresolve/src/remediate.rs +++ b/libresolve/src/remediate.rs @@ -33,17 +33,15 @@ pub extern "C" fn resolve_stack_obj(ptr: *mut c_void, size: usize) -> () { } #[unsafe(no_mangle)] -pub extern "C" fn resolve_invalidate_stack(base: *mut c_void, limit: *mut c_void) { +pub extern "C" fn resolve_invalidate_stack(base: *mut c_void) { let base = base as Vaddr; - let limit = limit as Vaddr; { let mut obj_list = ALIVE_OBJ_LIST.lock(); - // TODO: Add these to a free list? - obj_list.invalidate_region(base, limit); + obj_list.invalidate_at(base); } - info!("[STACK] Free range 0x{base:x}..=0x{limit:x}"); + info!("[STACK] Free addr 0x{base:x}"); } /** diff --git a/libresolve/src/shadowobjs.rs b/libresolve/src/shadowobjs.rs index a7e3b7fe..878a47d2 100644 --- a/libresolve/src/shadowobjs.rs +++ b/libresolve/src/shadowobjs.rs @@ -94,6 +94,7 @@ impl ShadowObjectTable { } /// Removes any allocation with a base address within the supplied region + #[allow(dead_code)] pub fn invalidate_region(&mut self, base: Vaddr, limit: Vaddr) { self.table .extract_if(base..=limit, |_, _| true) diff --git a/llvm-plugin/src/CVEAssert/bounds_check.cpp b/llvm-plugin/src/CVEAssert/bounds_check.cpp index 87f056c5..a8ec55b6 100644 --- a/llvm-plugin/src/CVEAssert/bounds_check.cpp +++ b/llvm-plugin/src/CVEAssert/bounds_check.cpp @@ -303,37 +303,60 @@ void instrumentAlloca(Function *F) { auto void_ty = Type::getVoidTy(Ctx); // Initialize list to store pointers to alloca and instructions - std::vector allocaList; + std::vector toFreeList; + + auto invalidateFn = M->getOrInsertFunction( + "resolve_invalidate_stack", + FunctionType::get(void_ty, { ptr_ty }, false) + ); + + auto handle_alloca = [&](auto* allocaInst) { + bool hasStart = false; + bool hasEnd = false; + + Type *allocatedType = allocaInst->getAllocatedType(); + uint64_t typeSize = DL.getTypeAllocSize(allocatedType); + + for (auto* user: allocaInst->users()) { + if( auto* call = dyn_cast(user)) { + auto called = call->getCalledFunction(); + if (called && called->getName().starts_with("llvm.lifetime.start")) { + hasStart = true; + builder.SetInsertPoint(call->getNextNode()); + builder.CreateCall(getResolveStackObj(M), { allocaInst, ConstantInt::get(size_ty, typeSize)}); + } + + if (called && called->getName().starts_with("llvm.lifetime.end")) { + hasEnd = true; + builder.SetInsertPoint(call->getNextNode()); + builder.CreateCall(invalidateFn, { allocaInst}); + } + } + } + + // This is probably always true unless we are given malformed input. + assert(hasStart == hasEnd); + if (hasStart) { return; } + // Otherwise Insert after the alloca instruction + builder.SetInsertPoint(allocaInst->getNextNode()); + builder.CreateCall(getResolveStackObj(M), { allocaInst, ConstantInt::get(size_ty, typeSize)}); + // If we have not added an invalidate call already make sure we do so later. + toFreeList.push_back(allocaInst); + }; for (auto &BB: *F) { for (auto &instr: BB) { if (auto *inst = dyn_cast(&instr)) { - allocaList.push_back(inst); + handle_alloca(inst); } } } - for (auto* allocaInst: allocaList) { - // Insert after the alloca instruction - builder.SetInsertPoint(allocaInst->getNextNode()); - Value* allocatedPtr = allocaInst; - Value *sizeVal = nullptr; - Type *allocatedType = allocaInst->getAllocatedType(); - uint64_t typeSize = DL.getTypeAllocSize(allocatedType); - sizeVal = ConstantInt::get(size_ty, typeSize); - builder.CreateCall(getResolveStackObj(M), { allocatedPtr, sizeVal }); - } - // Find low and high allocations and pass to resolve_invaliate_stack - if (allocaList.empty()) { + if (toFreeList.empty()) { return; } - auto invalidateFn = M->getOrInsertFunction( - "resolve_invalidate_stack", - FunctionType::get(void_ty, { ptr_ty, ptr_ty }, false) - ); - // Stack grows down, so first allocation is high, last is low // Hmm.. compiler seems to be reordering the allocas in ways // that break this assumption @@ -344,8 +367,8 @@ void instrumentAlloca(Function *F) { if (auto *inst = dyn_cast(&instr)) { builder.SetInsertPoint(inst); // builder.CreateCall(invalidateFn, { low, high }); - for (auto *alloca: allocaList) { - builder.CreateCall(invalidateFn, { alloca, alloca }); + for (auto *alloca: toFreeList) { + builder.CreateCall(invalidateFn, { alloca }); } } } @@ -631,4 +654,4 @@ void sanitizeMemInstBounds(Function *F, Vulnerability::RemediationStrategies str instrumentGEP(F); sanitizeMemcpy(F, strategy); sanitizeLoadStore(F, strategy); -} \ No newline at end of file +}