diff --git a/inc/trilogy/client.h b/inc/trilogy/client.h index afce6273..59f5f1b2 100644 --- a/inc/trilogy/client.h +++ b/inc/trilogy/client.h @@ -63,6 +63,8 @@ */ typedef trilogy_column_packet_t trilogy_column_t; +typedef struct trilogy_stmt trilogy_stmt_t; + /* trilogy_conn_t - The Trilogy client's instance type. * * This type is shared for the non-blocking and blocking versions of the API. @@ -82,6 +84,7 @@ typedef struct { uint16_t server_status; trilogy_sock_t *socket; + trilogy_stmt_t *prepared_statements; // private: uint8_t recv_buff[TRILOGY_DEFAULT_BUF_SIZE]; @@ -619,7 +622,16 @@ int trilogy_stmt_prepare_send(trilogy_conn_t *conn, const char *stmt, size_t stm /* trilogy_stmt_t - The trilogy client's prepared statement type. */ -typedef trilogy_stmt_ok_packet_t trilogy_stmt_t; + +struct trilogy_stmt { + trilogy_stmt_t *prev; + trilogy_stmt_t *next; + uint32_t id; + uint16_t column_count; + uint16_t parameter_count; + uint16_t warning_count; + trilogy_conn_t *connection; +}; /* trilogy_stmt_prepare_recv - Read the prepared statement prepare command response * from the MySQL-compatible server. diff --git a/src/blocking.c b/src/blocking.c index 3afa562e..119e9aa0 100644 --- a/src/blocking.c +++ b/src/blocking.c @@ -263,6 +263,8 @@ int trilogy_close(trilogy_conn_t *conn) int trilogy_stmt_prepare(trilogy_conn_t *conn, const char *stmt, size_t stmt_len, trilogy_stmt_t *stmt_out) { + memset(stmt_out, 0, sizeof(trilogy_stmt_t)); + int rc = trilogy_stmt_prepare_send(conn, stmt, stmt_len); if (rc == TRILOGY_AGAIN) { @@ -276,6 +278,16 @@ int trilogy_stmt_prepare(trilogy_conn_t *conn, const char *stmt, size_t stmt_len while (1) { rc = trilogy_stmt_prepare_recv(conn, stmt_out); + if (rc == TRILOGY_OK) { + stmt_out->connection = conn; + if (conn->prepared_statements) { + stmt_out->next = conn->prepared_statements; + stmt_out->next->prev = stmt_out; + } + conn->prepared_statements = stmt_out; + return rc; + } + if (rc != TRILOGY_AGAIN) { return rc; } @@ -365,6 +377,10 @@ int trilogy_stmt_reset(trilogy_conn_t *conn, trilogy_stmt_t *stmt) int trilogy_stmt_close(trilogy_conn_t *conn, trilogy_stmt_t *stmt) { + if (!stmt->connection || conn != stmt->connection) { + // User BUG!!! Return an error or crash? + } + int rc = trilogy_stmt_close_send(conn, stmt); if (rc == TRILOGY_AGAIN) { @@ -375,6 +391,20 @@ int trilogy_stmt_close(trilogy_conn_t *conn, trilogy_stmt_t *stmt) return rc; } + if (stmt->prev == NULL) { + // assert stmt->connection->prepared_statements == stmt + stmt->connection->prepared_statements = stmt->next; + if (stmt->next) { + stmt->next->prev = NULL; + } + } else { + stmt->prev->next = stmt->next; + if (stmt->next) { + stmt->next->prev = stmt->prev; + } + } + + stmt->connection = NULL; return TRILOGY_OK; } diff --git a/src/client.c b/src/client.c index e78abd6e..3d3907b8 100644 --- a/src/client.c +++ b/src/client.c @@ -145,6 +145,8 @@ int trilogy_init(trilogy_conn_t *conn) conn->recv_buff_pos = 0; conn->recv_buff_len = 0; + conn->prepared_statements = NULL; + trilogy_packet_parser_init(&conn->packet_parser, &packet_parser_callbacks); conn->packet_parser.user_data = &conn->packet_buffer; @@ -765,6 +767,12 @@ void trilogy_free(trilogy_conn_t *conn) conn->socket = NULL; } + trilogy_stmt_t *stmt = conn->prepared_statements; + while (stmt) { + stmt->connection = NULL; + stmt = stmt->next; + } + trilogy_buffer_free(&conn->packet_buffer); } @@ -803,13 +811,19 @@ int trilogy_stmt_prepare_recv(trilogy_conn_t *conn, trilogy_stmt_t *stmt_out) switch (current_packet_type(conn)) { case TRILOGY_PACKET_OK: { - err = trilogy_parse_stmt_ok_packet(conn->packet_buffer.buff, conn->packet_buffer.len, stmt_out); + trilogy_stmt_ok_packet_t out_packet; + err = trilogy_parse_stmt_ok_packet(conn->packet_buffer.buff, conn->packet_buffer.len, &out_packet); if (err < 0) { return err; } conn->warning_count = stmt_out->warning_count; + stmt_out->connection = conn; + stmt_out->id = out_packet.id; + stmt_out->column_count = out_packet.column_count; + stmt_out->parameter_count = out_packet.parameter_count; + stmt_out->warning_count = out_packet.warning_count; return TRILOGY_OK; } diff --git a/test/blocking_test.c b/test/blocking_test.c index cd37a1c3..eddbc575 100644 --- a/test/blocking_test.c +++ b/test/blocking_test.c @@ -176,6 +176,7 @@ TEST test_blocking_stmt_prepare() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -204,6 +205,7 @@ TEST test_blocking_stmt_execute_str() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -258,6 +260,7 @@ TEST test_blocking_stmt_execute_integer() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -332,6 +335,7 @@ TEST test_blocking_stmt_execute_double() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -384,6 +388,7 @@ TEST test_blocking_stmt_execute_float() { int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -443,6 +448,7 @@ TEST test_blocking_stmt_execute_long() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -516,6 +522,7 @@ TEST test_blocking_stmt_execute_short() { int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -589,6 +596,7 @@ TEST test_blocking_stmt_execute_tiny() { int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -663,6 +671,7 @@ TEST test_blocking_stmt_execute_datetime() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(0, stmt.parameter_count); @@ -714,6 +723,7 @@ TEST test_blocking_stmt_execute_time() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(0, stmt.parameter_count); @@ -762,6 +772,7 @@ TEST test_blocking_stmt_execute_year() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(0, stmt.parameter_count); @@ -808,6 +819,7 @@ TEST test_blocking_stmt_reset() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -839,6 +851,7 @@ TEST test_blocking_stmt_close() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -851,11 +864,82 @@ TEST test_blocking_stmt_close() trilogy_column_packet_t column_def; err = trilogy_read_full_column(&conn, &column_def); ASSERT_OK(err); + ASSERT_EQ(conn.prepared_statements, &stmt); + + const char *query2 = "SELECT YEAR('2022-01-31')"; + trilogy_stmt_t stmt2; + + err = trilogy_stmt_prepare(&conn, query2, strlen(query2), &stmt2); + ASSERT_OK(err); + ASSERT(stmt2.connection); + + ASSERT_EQ(0, stmt2.parameter_count); + + trilogy_column_packet_t param2; + err = trilogy_read_full_column(&conn, ¶m2); + ASSERT_OK(err); + + ASSERT_EQ(1, stmt2.column_count); + + ASSERT_EQ(conn.prepared_statements, &stmt2); + + err = trilogy_stmt_close(&conn, &stmt2); + ASSERT_OK(err); + ASSERT_EQ(conn.prepared_statements, &stmt); err = trilogy_stmt_close(&conn, &stmt); ASSERT_OK(err); + ASSERT_EQ(conn.prepared_statements, NULL); + + trilogy_free(&conn); + PASS(); +} + +TEST test_blocking_stmt_conn_close() +{ + trilogy_conn_t conn; + + connect_conn(&conn); + + const char *query = "SELECT ?"; + trilogy_stmt_t stmt; + + int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); + ASSERT_OK(err); + ASSERT(stmt.connection); + + ASSERT_EQ(1, stmt.parameter_count); + + trilogy_column_packet_t param; + err = trilogy_read_full_column(&conn, ¶m); + ASSERT_OK(err); + + ASSERT_EQ(1, stmt.column_count); + + trilogy_column_packet_t column_def; + err = trilogy_read_full_column(&conn, &column_def); + ASSERT_OK(err); + + const char *query2 = "SELECT YEAR('2022-01-31')"; + trilogy_stmt_t stmt2; + + err = trilogy_stmt_prepare(&conn, query2, strlen(query2), &stmt2); + ASSERT_OK(err); + ASSERT(stmt2.connection); + + ASSERT_EQ(0, stmt2.parameter_count); + + trilogy_column_packet_t param2; + err = trilogy_read_full_column(&conn, ¶m2); + ASSERT_OK(err); + + ASSERT_EQ(1, stmt2.column_count); trilogy_free(&conn); + + ASSERT(stmt.connection == NULL); + ASSERT(stmt2.connection == NULL); + PASS(); } @@ -881,6 +965,7 @@ int blocking_test() RUN_TEST(test_blocking_stmt_execute_year); RUN_TEST(test_blocking_stmt_reset); RUN_TEST(test_blocking_stmt_close); + RUN_TEST(test_blocking_stmt_conn_close); return 0; }