Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions libresolve/src/remediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,15 @@ pub extern "C" fn resolve_stack_obj(ptr: *mut c_void, size: usize) -> () {
}

#[unsafe(no_mangle)]
pub extern "C" fn resolve_invalidate_stack(base: *mut c_void, limit: *mut c_void) {
pub extern "C" fn resolve_invalidate_stack(base: *mut c_void) {
let base = base as Vaddr;
let limit = limit as Vaddr;

{
let mut obj_list = ALIVE_OBJ_LIST.lock();
// TODO: Add these to a free list?
obj_list.invalidate_region(base, limit);
obj_list.invalidate_at(base);
}

info!("[STACK] Free range 0x{base:x}..=0x{limit:x}");
info!("[STACK] Free addr 0x{base:x}");
}

/**
Expand Down
1 change: 1 addition & 0 deletions libresolve/src/shadowobjs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ impl ShadowObjectTable {
}

/// Removes any allocation with a base address within the supplied region
#[allow(dead_code)]
pub fn invalidate_region(&mut self, base: Vaddr, limit: Vaddr) {
self.table
.extract_if(base..=limit, |_, _| true)
Expand Down
67 changes: 45 additions & 22 deletions llvm-plugin/src/CVEAssert/bounds_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,37 +303,60 @@ void instrumentAlloca(Function *F) {
auto void_ty = Type::getVoidTy(Ctx);

// Initialize list to store pointers to alloca and instructions
std::vector<AllocaInst *> allocaList;
std::vector<AllocaInst *> toFreeList;

auto invalidateFn = M->getOrInsertFunction(
"resolve_invalidate_stack",
FunctionType::get(void_ty, { ptr_ty }, false)
);

auto handle_alloca = [&](auto* allocaInst) {
bool hasStart = false;
bool hasEnd = false;

Type *allocatedType = allocaInst->getAllocatedType();
uint64_t typeSize = DL.getTypeAllocSize(allocatedType);

for (auto* user: allocaInst->users()) {
if( auto* call = dyn_cast<CallInst>(user)) {
auto called = call->getCalledFunction();
if (called && called->getName().starts_with("llvm.lifetime.start")) {
hasStart = true;
builder.SetInsertPoint(call->getNextNode());
builder.CreateCall(getResolveStackObj(M), { allocaInst, ConstantInt::get(size_ty, typeSize)});
}

if (called && called->getName().starts_with("llvm.lifetime.end")) {
hasEnd = true;
builder.SetInsertPoint(call->getNextNode());
builder.CreateCall(invalidateFn, { allocaInst});
}
}
}

// This is probably always true unless we are given malformed input.
assert(hasStart == hasEnd);
if (hasStart) { return; }
// Otherwise Insert after the alloca instruction
builder.SetInsertPoint(allocaInst->getNextNode());
builder.CreateCall(getResolveStackObj(M), { allocaInst, ConstantInt::get(size_ty, typeSize)});
// If we have not added an invalidate call already make sure we do so later.
toFreeList.push_back(allocaInst);
};

for (auto &BB: *F) {
for (auto &instr: BB) {
if (auto *inst = dyn_cast<AllocaInst>(&instr)) {
allocaList.push_back(inst);
handle_alloca(inst);
}
}
}

for (auto* allocaInst: allocaList) {
// Insert after the alloca instruction
builder.SetInsertPoint(allocaInst->getNextNode());
Value* allocatedPtr = allocaInst;
Value *sizeVal = nullptr;
Type *allocatedType = allocaInst->getAllocatedType();
uint64_t typeSize = DL.getTypeAllocSize(allocatedType);
sizeVal = ConstantInt::get(size_ty, typeSize);
builder.CreateCall(getResolveStackObj(M), { allocatedPtr, sizeVal });
}

// Find low and high allocations and pass to resolve_invaliate_stack
if (allocaList.empty()) {
if (toFreeList.empty()) {
return;
}

auto invalidateFn = M->getOrInsertFunction(
"resolve_invalidate_stack",
FunctionType::get(void_ty, { ptr_ty, ptr_ty }, false)
);

// Stack grows down, so first allocation is high, last is low
// Hmm.. compiler seems to be reordering the allocas in ways
// that break this assumption
Expand All @@ -344,8 +367,8 @@ void instrumentAlloca(Function *F) {
if (auto *inst = dyn_cast<ReturnInst>(&instr)) {
builder.SetInsertPoint(inst);
// builder.CreateCall(invalidateFn, { low, high });
for (auto *alloca: allocaList) {
builder.CreateCall(invalidateFn, { alloca, alloca });
for (auto *alloca: toFreeList) {
builder.CreateCall(invalidateFn, { alloca });
}
}
}
Expand Down Expand Up @@ -631,4 +654,4 @@ void sanitizeMemInstBounds(Function *F, Vulnerability::RemediationStrategies str
instrumentGEP(F);
sanitizeMemcpy(F, strategy);
sanitizeLoadStore(F, strategy);
}
}