Skip to content
75 changes: 50 additions & 25 deletions resolve-cveassert/src/CVEAssert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,31 +184,60 @@ struct LabelCVEPass : public PassInfoMixin<LabelCVEPass> {
sanitizeIntOverflow(&F, strategy);
}

/// For each function, if it matches the target function name, insert calls to
/// the vulnerability handlers as specified in the JSON. Each call receives
/// the triggering argument parsed from the JSON.
PreservedAnalyses runOnFunction(Function &F, ModuleAnalysisManager &MAM,
Vulnerability &vuln) {
/// Return true if F's name (raw or demangled) contains `targetName
///
/// Always returns true if targetName is empty
bool nameMatches(Function &F, const std::string &demangledName,
const std::string &targetName) {
// Empty function name matches all functions
if (targetName.empty()) {
return true;
}

// First check demangled name
if (demangledName.find(targetName) != std::string::npos) {
return true;
}

// Next fallback to raw symbols
if (F.getName().str().find(targetName) != std::string::npos) {
return true;
}

return false;
}

/// Return true if `F` meets instrumentation critera for vuln
bool shouldInstrument(Function &F, Vulnerability &vuln) {
// Skip noinstrument functions
if (F.getMetadata("resolve.noinstrument")) {
return false;
}

char *demangledNamePtr = llvm::itaniumDemangle(F.getName().str(), false);
std::string demangledName(demangledNamePtr ?: "");
auto result = PreservedAnalyses::all();

if (F.getMetadata("resolve.noinstrument")) { return result; }

if (CVE_ASSERT_DEBUG) {
errs() << "[CVEAssert] Trying fn " << F.getName()
<< " Demangled name: " << demangledName << "\n";
}

raw_ostream &out = errs();
return nameMatches(F, demangledName, vuln.TargetFunctionName);
}

if (vuln.TargetFunctionName.empty() ||
(demangledName.find(vuln.TargetFunctionName) == std::string::npos &&
F.getName().str().find(vuln.TargetFunctionName) ==
std::string::npos)) {
/// For each function, if it matches the target function name, insert calls to
/// the vulnerability handlers as specified in the JSON. Each call receives
/// the triggering argument parsed from the JSON.
PreservedAnalyses runOnFunction(Function &F, ModuleAnalysisManager &MAM,
Vulnerability &vuln) {
auto result = PreservedAnalyses::all();

if (!shouldInstrument(F, vuln)) {
return result;
}

raw_ostream &out = errs();

out << "[CVEAssert] === Pre Instrumented IR === \n";
out << F;

Expand All @@ -223,9 +252,9 @@ struct LabelCVEPass : public PassInfoMixin<LabelCVEPass> {
}

if (vuln.Strategy == Vulnerability::RemediationStrategies::NONE) {
errs() << "[CVEAssert] NONE strategy selected for " << vuln.TargetFileName
<< ":" << vuln.TargetFunctionName << "...\n";
errs() << "[CVEAssert] Skipping remediation\n";
out << "[CVEAssert] NONE strategy selected for " << vuln.TargetFileName
<< ":" << vuln.TargetFunctionName << "...\n";
out << "[CVEAssert] Skipping remediation\n";
return result;
}

Expand Down Expand Up @@ -275,20 +304,16 @@ struct LabelCVEPass : public PassInfoMixin<LabelCVEPass> {
break;

default:
errs() << "[CVEAssert] Error: CWE " << vuln.WeaknessID
<< " not implemented\n";
out << "[CVEAssert] Error: CWE " << vuln.WeaknessID
<< " not implemented\n";
break;
}

out << "[CVEAssert] === Post Instrumented IR === \n";
out << F;

if (verifyFunction(F, &out)) {
report_fatal_error("[CVEAssert] We broke something");
}
validateIR(&F);

errs() << "[CVEAssert] Inserted vulnerability handler calls in function "
<< vuln.TargetFileName << ":" << vuln.TargetFunctionName << "\n";
out << "[CVEAssert] Inserted vulnerability handler calls in function "
<< vuln.TargetFileName << ":" << vuln.TargetFunctionName << "\n";
return result;
}

Expand Down