Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 69 additions & 1 deletion src/agents/query_engine/PatternMatchingQueryProcessor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// clang-format off

#ifndef LOG_LEVEL
#define LOG_LEVEL INFO_LEVEL
#define LOG_LEVEL DEBUG_LEVEL
#endif
#include "Logger.h"

Expand All @@ -17,6 +17,7 @@
#include "MettaParserActions.h"
#include "Node.h"
#include "Or.h"
#include "Chain.h"
#include "PatternMatchingQueryProxy.h"
#include "ServiceBus.h"
#include "Sink.h"
Expand All @@ -33,6 +34,7 @@ using namespace attention_broker;

string PatternMatchingQueryProcessor::AND = "AND";
string PatternMatchingQueryProcessor::OR = "OR";
string PatternMatchingQueryProcessor::CHAIN = "CHAIN";

// -------------------------------------------------------------------------------------------------
// Constructors and destructors
Expand Down Expand Up @@ -274,6 +276,8 @@ shared_ptr<QueryElement> PatternMatchingQueryProcessor::setup_query_tree(
(query_tokens[cursor] == LinkSchema::ATOM) || (query_tokens[cursor] == AND) ||
(query_tokens[cursor] == OR)) {
cursor += 2;
} else if (query_tokens[cursor] == CHAIN) {
cursor += 4;
} else {
Utils::error("Invalid token in query: " + query_tokens[cursor]);
}
Expand Down Expand Up @@ -309,6 +313,8 @@ shared_ptr<QueryElement> PatternMatchingQueryProcessor::setup_query_tree(
if (proxy->parameters.get<bool>(BaseQueryProxy::UNIQUE_ASSIGNMENT_FLAG)) {
element_stack.push(build_unique_assignment_filter(proxy, cursor, element_stack));
}
} else if (query_tokens[cursor] == CHAIN) {
element_stack.push(build_chain(proxy, cursor, element_stack));
} else {
Utils::error("Invalid token " + query_tokens[cursor] + " in PATTERN_MATCHING_QUERY message");
}
Expand All @@ -326,6 +332,7 @@ shared_ptr<QueryElement> PatternMatchingQueryProcessor::build_link_template(
shared_ptr<PatternMatchingQueryProxy> proxy,
unsigned int cursor,
stack<shared_ptr<QueryElement>>& element_stack) {
LOG_DEBUG("Building LinkTemplate...");
const vector<string> query_tokens = proxy->get_query_tokens();
unsigned int arity = std::stoi(query_tokens[cursor + 2]);
if (element_stack.size() < arity) {
Expand All @@ -346,6 +353,8 @@ shared_ptr<QueryElement> PatternMatchingQueryProcessor::build_link_template(
proxy->parameters.get<bool>(PatternMatchingQueryProxy::DISREGARD_IMPORTANCE_FLAG),
proxy->parameters.get<bool>(PatternMatchingQueryProxy::UNIQUE_VALUE_FLAG),
proxy->parameters.get<bool>(BaseQueryProxy::USE_LINK_TEMPLATE_CACHE));
LOG_DEBUG("New LinkTemplate: " + link_template->to_string());
LOG_DEBUG("Building LinkTemplate... Done.");
return link_template;
}

Expand Down Expand Up @@ -411,9 +420,11 @@ shared_ptr<QueryElement> PatternMatchingQueryProcessor::build_and(
link_templates.push_back(element_stack.top()); \
link_template->build(); \
clauses[i] = link_template->get_source_element(); \
LOG_DEBUG("OR input[" << i << "]: " << element_stack.top()->to_string()); \
} else { \
if (element_stack.top()->is_operator) { \
clauses[i] = element_stack.top(); \
LOG_DEBUG("OR input[" << i << "]: " << element_stack.top()->to_string()); \
} else { \
Utils::error("All OR clauses are supposed to be LinkTemplate or Operator"); \
} \
Expand All @@ -427,6 +438,7 @@ shared_ptr<QueryElement> PatternMatchingQueryProcessor::build_or(
shared_ptr<PatternMatchingQueryProxy> proxy,
unsigned int cursor,
stack<shared_ptr<QueryElement>>& element_stack) {
LOG_DEBUG("Building OR operator");
const vector<string> query_tokens = proxy->get_query_tokens();
unsigned int num_clauses = std::stoi(query_tokens[cursor + 1]);
if (element_stack.size() < num_clauses) {
Expand All @@ -452,6 +464,62 @@ shared_ptr<QueryElement> PatternMatchingQueryProcessor::build_or(
return NULL; // Just to avoid warnings. This is not actually reachable.
}

shared_ptr<QueryElement> PatternMatchingQueryProcessor::build_chain(
shared_ptr<PatternMatchingQueryProxy> proxy,
unsigned int cursor,
stack<shared_ptr<QueryElement>>& element_stack) {
LOG_DEBUG("Building CHAIN operator...");
const vector<string> query_tokens = proxy->get_query_tokens();
QueryAnswerElement link_selector;
if (isdigit(static_cast<unsigned char>(query_tokens[cursor + 1][0]))) {
link_selector.set(Utils::string_to_uint(query_tokens[cursor + 1]));
LOG_DEBUG("Link selector is handle index: " + query_tokens[cursor + 1]);
} else {
link_selector.set(query_tokens[cursor + 1]);
LOG_DEBUG("Link selector is variable: " + query_tokens[cursor + 1]);
}
unsigned int tail_reference = Utils::string_to_uint(query_tokens[cursor + 2]);
unsigned int head_reference = Utils::string_to_uint(query_tokens[cursor + 3]);
LOG_DEBUG("Tail reference: " + std::to_string(tail_reference));
LOG_DEBUG("Head reference: " + std::to_string(head_reference));

if (element_stack.size() < 3) {
Utils::error(
"PATTERN_MATCHING_QUERY message: parse error in tokens - too few arguments for "
"CHAIN");
}

shared_ptr<Terminal> source = dynamic_pointer_cast<Terminal>(element_stack.top());
element_stack.pop();
shared_ptr<Terminal> target = dynamic_pointer_cast<Terminal>(element_stack.top());
element_stack.pop();
LOG_DEBUG("Source terminal: " + source->to_string());
LOG_DEBUG("Target terminal: " + target->to_string());
LOG_DEBUG("Source handle: " + source->compute_handle());
LOG_DEBUG("Target handle: " + target->compute_handle());

array<shared_ptr<QueryElement>, 1> clauses;
clauses[0] = element_stack.top();
shared_ptr<LinkTemplate> link_template = dynamic_pointer_cast<LinkTemplate>(clauses[0]);
if (link_template != nullptr) {
link_template->build();
clauses[0] = link_template->get_source_element();
}
LOG_DEBUG("Input: " + clauses[0]->to_string());
element_stack.pop();

auto chain_operator = make_shared<Chain>(clauses,
link_template,
source->compute_handle(),
target->compute_handle(),
link_selector,
tail_reference,
head_reference);
LOG_DEBUG("Building CHAIN operator... DONE");

return chain_operator;
}

shared_ptr<QueryElement> PatternMatchingQueryProcessor::build_link(
shared_ptr<PatternMatchingQueryProxy> proxy,
unsigned int cursor,
Expand Down
5 changes: 5 additions & 0 deletions src/agents/query_engine/PatternMatchingQueryProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ class PatternMatchingQueryProcessor : public BusCommandProcessor {
unsigned int cursor,
stack<shared_ptr<QueryElement>>& element_stack);

shared_ptr<QueryElement> build_chain(shared_ptr<PatternMatchingQueryProxy> proxy,
unsigned int cursor,
stack<shared_ptr<QueryElement>>& element_stack);

shared_ptr<QueryElement> build_link(shared_ptr<PatternMatchingQueryProxy> proxy,
unsigned int cursor,
stack<shared_ptr<QueryElement>>& element_stack);
Expand All @@ -88,6 +92,7 @@ class PatternMatchingQueryProcessor : public BusCommandProcessor {
shared_ptr<AtomDB> atomdb;
static string AND;
static string OR;
static string CHAIN;
};

} // namespace atomdb
17 changes: 17 additions & 0 deletions src/agents/query_engine/QueryAnswer.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "Assignment.h"
#include "QueryAnswer.h"
#include "Utils.h"
#include "expression_hasher.h"

using namespace std;
Expand Down Expand Up @@ -34,6 +35,22 @@ class QueryAnswerElement {
this->name = other.name;
return *this;
}
void set(const string& key) {
if (this->type == UNDEFINED) {
this->type = VARIABLE;
this->name = key;
} else {
Utils::error("Invalid attempt to reset a QueryAnswerElement");
}
}
void set(unsigned int key) {
if (this->type == UNDEFINED) {
this->type = HANDLE;
this->index = key;
} else {
Utils::error("Invalid attempt to reset a QueryAnswerElement");
}
}
string to_string() {
if (this->type == HANDLE) {
return "_" + std::to_string(this->index);
Expand Down
2 changes: 2 additions & 0 deletions src/agents/query_engine/query_element/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ cc_library(
hdrs = ["Chain.h"],
deps = [
":operator",
"//agents/query_engine/query_element:link_template",
"//agents/query_engine/query_element:source",
"//atomdb:atomdb_singleton",
"//commons:commons_lib",
"//commons/atoms:atoms_lib",
Expand Down
124 changes: 65 additions & 59 deletions src/agents/query_engine/query_element/Chain.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,27 @@ static string convert_handle(const string& handle) {
// Public methods

Chain::Chain(const array<shared_ptr<QueryElement>, 1>& clauses,
shared_ptr<LinkTemplate> link_template,
const string& source_handle,
const string& target_handle)
: Operator<1>(clauses), source_handle(source_handle), target_handle(target_handle) {
const string& target_handle,
const QueryAnswerElement& link_selector,
unsigned int tail_reference,
unsigned int head_reference)
: Operator<1>(clauses),
input_link_template(link_template),
source_handle(source_handle),
target_handle(target_handle),
link_selector(link_selector),
tail_reference(tail_reference),
head_reference(head_reference) {
initialize(clauses);
}

Chain::Chain(const array<shared_ptr<QueryElement>, 1>& clauses,
const string& source_handle,
const string& target_handle)
: Chain(clauses, nullptr, source_handle, target_handle, QueryAnswerElement(0), 1, 2) {}

Chain::~Chain() {
LOG_DEBUG("Chain::~Chain() BEGIN");
graceful_shutdown();
Expand Down Expand Up @@ -201,6 +216,7 @@ bool Chain::PathFinder::thread_one_step() {
<< "Pushing new path: " << new_path.to_string());
base_heap->push(new_path, new_path.path_sti);
} else {
LOG_DEBUG("[PATH_FINDER] Discarding because candidate would lead to a cycle.");
count_cycles++;
}
}
Expand All @@ -219,6 +235,7 @@ bool Chain::PathFinder::thread_one_step() {
void Chain::refeed_paths() {
while (!this->refeeding_buffer.empty()) {
Path path = refeeding_buffer.front_and_pop();
LOG_DEBUG("Refeeding: " << path.to_string());
if (path.forward_flag) {
this->source_index[this->source_handle]->push(path, path.path_sti);
} else {
Expand Down Expand Up @@ -256,52 +273,52 @@ bool Chain::thread_one_step() {
if ((answer = dynamic_cast<QueryAnswer*>(this->input_buffer[0]->pop_query_answer())) != NULL) {
LOG_DEBUG("[CHAIN OPERATOR] "
<< "New query answer: " << answer->to_string());
for (string handle : answer->handles) {
auto iterator = this->known_links.find(handle);
if (iterator == this->known_links.end()) {
this->known_links.insert(iterator, handle);
shared_ptr<Link> link =
dynamic_pointer_cast<Link>(AtomDBSingleton::get_instance()->get_atom(handle));
if (link == nullptr) {
Utils::error("Invalid query answer in Chain operator.");
} else {
LOG_DEBUG("[CHAIN OPERATOR] "
<< "Valid link");
}
string handle = answer->get(this->link_selector);
auto iterator = this->known_links.find(handle);
if (iterator == this->known_links.end()) {
this->known_links.insert(iterator, handle);
shared_ptr<Link> link =
dynamic_pointer_cast<Link>(AtomDBSingleton::get_instance()->get_atom(handle));
if (link == nullptr) {
Utils::error("Invalid query answer in Chain operator.");
} else {
LOG_DEBUG("[CHAIN OPERATOR] "
<< "New link: " << link->to_string());
if (link->arity() == 3) {
{
lock_guard<mutex> semaphore(this->source_index_mutex);
for (unsigned int i = 1; i <= 2; i++) {
if (this->source_index.find(link->targets[i]) ==
this->source_index.end()) {
this->source_index[link->targets[i]] = make_shared<HeapType>();
}
<< "Valid link");
}
LOG_DEBUG("[CHAIN OPERATOR] "
<< "New link: " << link->to_string());
if (link->arity() > max(this->tail_reference, this->head_reference)) {
string tail = link->targets[this->tail_reference];
string head = link->targets[this->head_reference];
{
lock_guard<mutex> semaphore(this->source_index_mutex);
for (string key : {tail, head}) {
if (this->source_index.find(key) == this->source_index.end()) {
this->source_index[key] = make_shared<HeapType>();
}
this->source_index[link->targets[1]]->push(Path(link, answer, true),
answer->importance);
}
{
lock_guard<mutex> semaphore(this->target_index_mutex);
for (unsigned int i = 1; i <= 2; i++) {
if (this->target_index.find(link->targets[i]) ==
this->target_index.end()) {
this->target_index[link->targets[i]] = make_shared<HeapType>();
}
this->source_index[tail]->push(Path(tail, head, answer, true),
answer->importance);
}
{
lock_guard<mutex> semaphore(this->target_index_mutex);
for (string key : {tail, head}) {
if (this->target_index.find(key) == this->target_index.end()) {
this->target_index[key] = make_shared<HeapType>();
}
this->target_index[link->targets[2]]->push(
Path(link, QueryAnswer::copy(answer), false), answer->importance);
}
} else {
Utils::error("Invalid Link " + link->to_string() + " with arity " +
std::to_string(link->arity()) + " in CHAIN operator.");
break;
this->target_index[head]->push(
Path(tail, head, QueryAnswer::copy(answer), false), answer->importance);
}
} else {
LOG_DEBUG("[CHAIN OPERATOR] "
<< "Discarding already inserted handle: " << convert_handle(handle));
Utils::error("Invalid Link " + link->to_string() + " with arity " +
std::to_string(link->arity()) + " in CHAIN operator. Tail reference: " +
std::to_string(this->tail_reference) +
". Head reference: " + std::to_string(this->head_reference));
}
} else {
LOG_DEBUG("[CHAIN OPERATOR] "
<< "Discarding already inserted handle: " << convert_handle(handle));
}
refeed_paths();
return true;
Expand All @@ -322,20 +339,12 @@ bool Chain::thread_one_step() {
void Chain::report_path(Path& path) {
QueryAnswer* query_answer = new QueryAnswer(path.path_sti);
if (path.forward_flag) {
for (auto pair : path.links) {
query_answer->add_handle(pair.first->handle()); // TODO change to use handle in query_answer
if (!query_answer->merge(pair.second.get())) {
Utils::error("Incompatible assignments in Chain operator answer: " +
query_answer->to_string() + " + " + pair.second->to_string());
}
for (auto pair : path.edges) {
query_answer->add_handle(pair.second->get(this->link_selector));
}
} else {
for (auto pair = path.links.rbegin(); pair != path.links.rend(); ++pair) {
query_answer->add_handle(pair->first->handle());
if (!query_answer->merge(pair->second.get())) {
Utils::error("Incompatible assignments in Chain operator answer: " +
query_answer->to_string() + " + " + pair->second->to_string());
}
for (auto pair = path.edges.rbegin(); pair != path.edges.rend(); ++pair) {
query_answer->add_handle(pair->second->get(this->link_selector));
}
}
string answer_hash = Hasher::composite_handle(query_answer->handles);
Expand Down Expand Up @@ -393,20 +402,17 @@ string Chain::Path::to_string() {
bool first = true;
string last_handle = "";
string check_handle = "";
for (auto pair : this->links) {
for (auto pair : this->edges) {
if (first) {
first = false;
last_handle =
convert_handle(this->forward_flag ? pair.first->targets[1] : pair.first->targets[2]);
last_handle = convert_handle(this->forward_flag ? pair.first.first : pair.first.second);
answer = last_handle;
}
check_handle =
convert_handle(this->forward_flag ? pair.first->targets[1] : pair.first->targets[2]);
check_handle = convert_handle(this->forward_flag ? pair.first.first : pair.first.second);
if (check_handle != last_handle) {
LOG_ERROR("Invalid Path");
}
last_handle =
convert_handle(this->forward_flag ? pair.first->targets[2] : pair.first->targets[1]);
last_handle = convert_handle(this->forward_flag ? pair.first.second : pair.first.first);
answer += this->forward_flag ? " -> " : " <- ";
answer += last_handle;
}
Expand Down
Loading
Loading