diff --git a/src/include/souffle/provenance/Explain.h b/src/include/souffle/provenance/Explain.h index d223e1a761c..86a3b4b2e6a 100644 --- a/src/include/souffle/provenance/Explain.h +++ b/src/include/souffle/provenance/Explain.h @@ -100,7 +100,13 @@ class Explain { return true; } query = parseTuple(command[1]); - printTree(prov.explain(query.first, query.second, ExplainConfig::getExplainConfig().depthLimit)); + try { + printTree(prov.explain( + query.first, query.second, ExplainConfig::getExplainConfig().depthLimit)); + } catch (const ValueReadException& e) { + printError(tfm::format("%s\n", e.what())); + return true; + } } else if (command[0] == "subproof") { std::pair> query; int label = -1; @@ -144,29 +150,40 @@ class Explain { printInfo(rules); printPrompt("Pick a rule number: "); - - std::string ruleNum = getInput(); - auto variables = prov.explainNegationGetVariables(query.first, query.second, std::stoi(ruleNum)); - - // @ and @non_matching are special sentinel values returned by ExplainProvenance - if (variables.size() == 1 && variables[0] == "@") { - printInfo("The tuple exists, cannot explain negation of it!\n"); - return true; - } else if (variables.size() == 1 && variables[0] == "@non_matching") { - printInfo("The variable bindings don't match, cannot explain!\n"); - return true; - } else if (variables.size() == 1 && variables[0] == "@fact") { - printInfo("The rule is a fact!\n"); + int ruleNum = 0; + try { + ruleNum = stoi(getInput()); + } catch (std::exception& e) { + printError("Invalid rule number\n"); return true; } - std::map varValues; - for (auto var : variables) { - printPrompt("Pick a value for " + var + ": "); - varValues[var] = getInput(); - } + try { + auto variables = prov.explainNegationGetVariables(query.first, query.second, ruleNum); + + // @ and @non_matching are special sentinel values returned by ExplainProvenance + if (variables.size() == 1 && variables[0] == "@") { + printInfo("The tuple exists, cannot explain negation of it!\n"); + return true; + } else if (variables.size() == 1 && variables[0] == "@non_matching") { + printInfo("The variable bindings don't match, cannot explain!\n"); + return true; + } else if (variables.size() == 1 && variables[0] == "@fact") { + printInfo("The rule is a fact!\n"); + return true; + } + + std::map varValues; + for (auto var : variables) { + printPrompt("Pick a value for " + var + ": "); + varValues[var] = getInput(); + } - printTree(prov.explainNegation(query.first, std::stoi(ruleNum), query.second, varValues)); + printTree(prov.explainNegation(query.first, ruleNum, query.second, varValues)); + } catch (const ValueReadException& e) { + printError(tfm::format("%s\n", e.what())); + return true; + } } else if (command[0] == "rule" && command.size() == 2) { auto query = split(command[1], ' '); if (query.size() != 2) { diff --git a/src/include/souffle/provenance/ExplainProvenance.h b/src/include/souffle/provenance/ExplainProvenance.h index e5031e621cc..cbaf8357755 100644 --- a/src/include/souffle/provenance/ExplainProvenance.h +++ b/src/include/souffle/provenance/ExplainProvenance.h @@ -35,6 +35,11 @@ namespace souffle { class TreeNode; +class ValueReadException : public std::runtime_error { +public: + ValueReadException(const std::string& what_arg) : std::runtime_error(what_arg) {} +}; + /** Equivalence class for variables in query command */ class Equivalence { public: @@ -240,15 +245,21 @@ class ExplainProvenance { } RamDomain valueRead(const char type, const std::string& value) const { - switch (type) { - case 'i': return ramBitCast(RamSignedFromString(value)); - case 'u': return ramBitCast(RamUnsignedFromString(value)); - case 'f': return ramBitCast(RamFloatFromString(value)); - case 's': - assert(2 <= value.size() && value[0] == '"' && value.back() == '"'); - return symTable.encode(value.substr(1, value.size() - 2)); - case 'r': fatal("not implemented"); - default: fatal("unhandled type attr code"); + try { + switch (type) { + case 'i': return ramBitCast(RamSignedFromString(value)); + case 'u': return ramBitCast(RamUnsignedFromString(value)); + case 'f': return ramBitCast(RamFloatFromString(value)); + case 's': + assert(2 <= value.size() && value[0] == '"' && value.back() == '"'); + return symTable.encode(value.substr(1, value.size() - 2)); + case 'r': fatal("not implemented"); + default: fatal("unhandled type attr code"); + } + } catch (const std::invalid_argument& e) { + throw ValueReadException(tfm::format("Invalid argument %s for type '%c'", value, type)); + } catch (const std::out_of_range& e) { + throw ValueReadException(tfm::format("Out of range value %s for type '%c'", value, type)); } } };