Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ REGRESS = scan \
name_validation \
jsonb_operators \
list_comprehension \
map_projection
map_projection \
concurrent

ifneq ($(EXTRA_TESTS),)
REGRESS += $(EXTRA_TESTS)
Expand Down
12 changes: 12 additions & 0 deletions regress/expected/concurrent.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
\! python3 regress/python/test_concurrent.py
Result: OK

Result: OK

Result: OK

Result: OK

Result: OK

All threads have finished execution.
170 changes: 170 additions & 0 deletions regress/python/test_concurrent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from contextlib import contextmanager
import psycopg2
from psycopg2 import sql
import threading

from concurrent.futures import ThreadPoolExecutor
from threading import Semaphore

sqls = [
""" SELECT * FROM cypher('test_graph', $$
MERGE (n:DDDDD {doc_id: 'f5ce4dc2'})
SET n.embedset_id='ae1b9b73', n.doc_id='f5ce4dc2', n.doc_hash='977b56ef'
$$) as (result agtype) """,
""" SELECT * FROM cypher('test_graph', $$
MERGE (n:EEEEE {doc_id: 'f5ce4dc2'})
SET n.embedset_id='ae1b9b73', n.doc_id='f5ce4dc2', n.doc_hash='1d7e79a0'
$$) as (result agtype) """,
""" SELECT * FROM cypher('test_graph', $$
MATCH (source:EEEEE {doc_id:'f5ce4dc2'})
MATCH (target:DDDDD {doc_id:'f5ce4dc2'})
WITH source, target
MERGE (source)-[r:DIRECTED]->(target)
SET r.embedset_id='ae1b9b73', r.doc_id='f5ce4dc2'
RETURN r
$$) as (result agtype) """,
""" SELECT * FROM cypher('test_graph', $$
MATCH (source:EEEEE {doc_id:'f5ce4dc2'})
MATCH (target:DDDDD {doc_id:'f5ce4dc2'})
WITH source, target
MERGE (source)-[r:DIRECTED]->(target)
SET r.embedset_id='ae1b9b73', r.doc_id='f5ce4dc2'
RETURN r
$$) as (result agtype) """,
""" SELECT * FROM cypher('test_graph', $$
MATCH (source:EEEEE {doc_id:'f5ce4dc2'})
MATCH (target:DDDDD {doc_id:'f5ce4dc2'})
WITH source, target
MERGE (source)-[r:DIRECTED]->(target)
SET r.embedset_id='ae1b9b73', r.doc_id='f5ce4dc2'
RETURN r
$$) as (result agtype) """,
]


class PieGraphConnector:
host: str
port: str
user: str
password: str
database: str
warehouse: str

def __init__(self, global_config: dict):
self.host = global_config.get("host", "")
self.port = global_config.get("port", "")
self.user = global_config.get("user", "")
self.password = global_config.get("password", "")
self.database = global_config.get("database", "")
self.warehouse = global_config.get("warehouse", "")

@contextmanager
def conn(self):
conn = None
if self.warehouse and self.warehouse != "":
options = "'-c warehouse=" + self.warehouse + "'"
conn = psycopg2.connect(
dbname=self.database,
user=self.user,
password=self.password,
host=self.host,
port=self.port,
options=options,
)
else:
conn = psycopg2.connect(
dbname=self.database,
user=self.user,
password=self.password,
host=self.host,
port=self.port,
)
conn.autocommit = True
with conn.cursor() as cursor:
cursor.execute("CREATE EXTENSION IF NOT EXISTS age;")
cursor.execute("LOAD 'age';")
cursor.execute("SET search_path = ag_catalog, '$user', public;")
try:
yield conn
finally:
conn.close()


class BoundedThreadPoolExecutor(ThreadPoolExecutor):
def __init__(self, max_workers=5, max_task_size=32, *args, **kwargs):
if max_task_size < max_workers:
raise ValueError(
"max_task_size should be greater than or equal to max_workers"
)
if max_workers is not None:
kwargs["max_workers"] = max_workers
super().__init__(*args, **kwargs)
self._semaphore = Semaphore(max_task_size)

def submit(self, fn, /, *args, **kwargs):
timeout = kwargs.get("timeout", None)
if self._semaphore.acquire(timeout=timeout):
future = super().submit(fn, *args, **kwargs)
future.add_done_callback(lambda _: self._semaphore.release())
return future
else:
raise TimeoutError("waiting for semaphore timeout")


db_config = {
"host": "127.0.0.1",
"database": "postgres",
"user": "postgres",
"password": "",
"port": "5432",
}

connector = PieGraphConnector(db_config)

def execute_sql(query):
try:
connection = psycopg2.connect(**db_config)
cursor = connection.cursor()
cursor.execute("CREATE EXTENSION IF NOT EXISTS age;")
cursor.execute("LOAD 'age';")
cursor.execute("SET search_path = ag_catalog, '$user', public;")

cursor.execute(query)

connection.commit()

result = cursor.fetchall()

except Exception as e:
print(f"Error executing query '{query}': {e}")
finally:
if connection:
cursor.close()
connection.close()

semaphore_graph = threading.Semaphore(20)

drop_graph = "SELECT * FROM drop_graph('test_graph', true)"
create_graph = "SELECT * FROM create_graph('test_graph')"

# execute_sql(drop_graph)
execute_sql(create_graph)

def _merge_exec_sql(query: str):
with semaphore_graph:
with connector.conn() as conn:
with conn.cursor() as cursor:
try:
cursor.execute(query)
result = cursor.fetchall()
print(f"Result: OK\n")
except Exception as e:
print(f"Error executing query '{query}': {e}")
conn.commit()

with BoundedThreadPoolExecutor() as executor:
executor.map(lambda q: _merge_exec_sql(q), sqls)

print("All threads have finished execution.")

execute_sql(drop_graph)
1 change: 1 addition & 0 deletions regress/sql/concurrent.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
\! python3 regress/python/test_concurrent.py
99 changes: 71 additions & 28 deletions src/backend/parser/cypher_clause.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@
#include "catalog/ag_graph.h"
#include "catalog/ag_label.h"
#include "commands/label_commands.h"
#include "common/hashfn.h"
#include "parser/cypher_analyze.h"
#include "parser/cypher_clause.h"
#include "parser/cypher_expr.h"
#include "parser/cypher_item.h"
#include "parser/cypher_parse_agg.h"
#include "parser/cypher_transform_entity.h"
#include "storage/lock.h"
#include "utils/ag_cache.h"
#include "utils/ag_func.h"
#include "utils/ag_guc.h"
Expand Down Expand Up @@ -5872,15 +5874,23 @@ transform_create_cypher_edge(cypher_parsestate *cpstate, List **target_list,
/* create the label entry if it does not exist */
if (!label_exists(edge->label, cpstate->graph_oid))
{
LOCKTAG tag;
uint32 key;
List *parent;

rv = get_label_range_var(cpstate->graph_name, cpstate->graph_oid,
AG_DEFAULT_LABEL_EDGE);
key = hash_bytes((const unsigned char *)edge->label, strlen(edge->label));
SET_LOCKTAG_ADVISORY(tag, MyDatabaseId, key, cpstate->graph_oid, 3);
(void) LockAcquire(&tag, ExclusiveLock, false, false);
if (!label_exists(edge->label, cpstate->graph_oid))
{
rv = get_label_range_var(cpstate->graph_name, cpstate->graph_oid,
AG_DEFAULT_LABEL_EDGE);

parent = list_make1(rv);
parent = list_make1(rv);

create_label(cpstate->graph_name, edge->label, LABEL_TYPE_EDGE,
parent);
create_label(cpstate->graph_name, edge->label, LABEL_TYPE_EDGE,
parent);
}
}

/* lock the relation of the label */
Expand Down Expand Up @@ -6149,15 +6159,23 @@ transform_create_cypher_new_node(cypher_parsestate *cpstate,
/* create the label entry if it does not exist */
if (!label_exists(node->label, cpstate->graph_oid))
{
LOCKTAG tag;
uint32 key;
List *parent;

rv = get_label_range_var(cpstate->graph_name, cpstate->graph_oid,
AG_DEFAULT_LABEL_VERTEX);
key = hash_bytes((const unsigned char *)node->label, strlen(node->label));
SET_LOCKTAG_ADVISORY(tag, MyDatabaseId, key, cpstate->graph_oid, 3);
(void) LockAcquire(&tag, ExclusiveLock, false, false);
if (!label_exists(node->label, cpstate->graph_oid))
{
rv = get_label_range_var(cpstate->graph_name, cpstate->graph_oid,
AG_DEFAULT_LABEL_VERTEX);

parent = list_make1(rv);
parent = list_make1(rv);

create_label(cpstate->graph_name, node->label, LABEL_TYPE_VERTEX,
parent);
create_label(cpstate->graph_name, node->label, LABEL_TYPE_VERTEX,
parent);
}
}

rel->flags = CYPHER_TARGET_NODE_FLAG_INSERT;
Expand Down Expand Up @@ -7222,19 +7240,36 @@ transform_merge_cypher_edge(cypher_parsestate *cpstate, List **target_list,
/* check to see if the label exists, create the label entry if it does not. */
if (edge->label && !label_exists(edge->label, cpstate->graph_oid))
{
LOCKTAG tag;
uint32 key;
List *parent;

/*
* setup the default edge table as the parent table, that we
* will inherit from.
* When merging nodes or edges concurrently, there is label with the same
* name created by different transactions. Advisory lock is acquired before
* creating label, and then check if label already exists. Note, the lock is
* not released until current transaction is over. This can ensure that the
* new tuple inserted in ag_label catalog table will be sent out, so other
* transactions can receive it when checking label exists after acquiring lock.
*/
rv = get_label_range_var(cpstate->graph_name, cpstate->graph_oid,
AG_DEFAULT_LABEL_EDGE);
key = hash_bytes((const unsigned char *)edge->label, strlen(edge->label));
SET_LOCKTAG_ADVISORY(tag, MyDatabaseId, key, cpstate->graph_oid, 3);
(void) LockAcquire(&tag, ExclusiveLock, false, false);
if (!label_exists(edge->label, cpstate->graph_oid))
{
/*
* setup the default edge table as the parent table, that we
* will inherit from.
*/
rv = get_label_range_var(cpstate->graph_name, cpstate->graph_oid,
AG_DEFAULT_LABEL_EDGE);

parent = list_make1(rv);
parent = list_make1(rv);

/* create the label */
create_label(cpstate->graph_name, edge->label, LABEL_TYPE_EDGE,
parent);
/* create the label */
create_label(cpstate->graph_name, edge->label, LABEL_TYPE_EDGE,
parent);
}
}

/* lock the relation of the label */
Expand Down Expand Up @@ -7357,20 +7392,28 @@ transform_merge_cypher_node(cypher_parsestate *cpstate, List **target_list,
/* check to see if the label exists, create the label entry if it does not. */
if (node->label && !label_exists(node->label, cpstate->graph_oid))
{
LOCKTAG tag;
uint32 key;
List *parent;

/*
* setup the default vertex table as the parent table, that we
* will inherit from.
*/
rv = get_label_range_var(cpstate->graph_name, cpstate->graph_oid,
AG_DEFAULT_LABEL_VERTEX);
key = hash_bytes((const unsigned char *)node->label, strlen(node->label));
SET_LOCKTAG_ADVISORY(tag, MyDatabaseId, key, cpstate->graph_oid, 3);
(void) LockAcquire(&tag, ExclusiveLock, false, false);
if (!label_exists(node->label, cpstate->graph_oid))
{
/*
* setup the default vertex table as the parent table, that we
* will inherit from.
*/
rv = get_label_range_var(cpstate->graph_name, cpstate->graph_oid,
AG_DEFAULT_LABEL_VERTEX);

parent = list_make1(rv);
parent = list_make1(rv);

/* create the label */
create_label(cpstate->graph_name, node->label, LABEL_TYPE_VERTEX,
parent);
/* create the label */
create_label(cpstate->graph_name, node->label, LABEL_TYPE_VERTEX,
parent);
}
}

rel->flags |= CYPHER_TARGET_NODE_FLAG_INSERT;
Expand Down
Loading