diff --git a/src/SciSharp.MySQL.Replication/Protocol/CommandPacket.cs b/src/SciSharp.MySQL.Replication/Protocol/CommandPacket.cs new file mode 100644 index 0000000..5c3b6ca --- /dev/null +++ b/src/SciSharp.MySQL.Replication/Protocol/CommandPacket.cs @@ -0,0 +1,86 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace SciSharp.MySQL.Replication.Protocol +{ + /// + /// Represents a MySQL command packet. + /// Reference: https://dev.mysql.com/doc/internals/en/text-protocol.html + /// + internal class CommandPacket : MySQLPacket + { + public MySQLCommand Command { get; set; } + public string Query { get; set; } + public byte[] Parameters { get; set; } + + public CommandPacket(MySQLCommand command, string query = null, byte[] parameters = null) + { + Command = command; + Query = query; + Parameters = parameters; + } + + protected override byte[] GetPayload() + { + var payload = new List(); + + // Command byte + payload.Add((byte)Command); + + // Query string for COM_QUERY + if (Command == MySQLCommand.COM_QUERY && !string.IsNullOrEmpty(Query)) + { + payload.AddRange(Encoding.UTF8.GetBytes(Query)); + } + + // Parameters for other commands + if (Parameters != null) + { + payload.AddRange(Parameters); + } + + return payload.ToArray(); + } + } + + /// + /// MySQL command types. + /// Reference: https://dev.mysql.com/doc/internals/en/command-phase.html + /// + internal enum MySQLCommand : byte + { + COM_SLEEP = 0x00, + COM_QUIT = 0x01, + COM_INIT_DB = 0x02, + COM_QUERY = 0x03, + COM_FIELD_LIST = 0x04, + COM_CREATE_DB = 0x05, + COM_DROP_DB = 0x06, + COM_REFRESH = 0x07, + COM_SHUTDOWN = 0x08, + COM_STATISTICS = 0x09, + COM_PROCESS_INFO = 0x0A, + COM_CONNECT = 0x0B, + COM_PROCESS_KILL = 0x0C, + COM_DEBUG = 0x0D, + COM_PING = 0x0E, + COM_TIME = 0x0F, + COM_DELAYED_INSERT = 0x10, + COM_CHANGE_USER = 0x11, + COM_BINLOG_DUMP = 0x12, + COM_TABLE_DUMP = 0x13, + COM_CONNECT_OUT = 0x14, + COM_REGISTER_SLAVE = 0x15, + COM_STMT_PREPARE = 0x16, + COM_STMT_EXECUTE = 0x17, + COM_STMT_SEND_LONG_DATA = 0x18, + COM_STMT_CLOSE = 0x19, + COM_STMT_RESET = 0x1A, + COM_SET_OPTION = 0x1B, + COM_STMT_FETCH = 0x1C, + COM_DAEMON = 0x1D, + COM_BINLOG_DUMP_GTID = 0x1E, + COM_RESET_CONNECTION = 0x1F + } +} \ No newline at end of file diff --git a/src/SciSharp.MySQL.Replication/Protocol/DirectMySQLConnection.cs b/src/SciSharp.MySQL.Replication/Protocol/DirectMySQLConnection.cs new file mode 100644 index 0000000..3d30fb2 --- /dev/null +++ b/src/SciSharp.MySQL.Replication/Protocol/DirectMySQLConnection.cs @@ -0,0 +1,334 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.IO; +using System.Net.Sockets; +using System.Text; +using System.Threading.Tasks; + +namespace SciSharp.MySQL.Replication.Protocol +{ + /// + /// Direct MySQL connection implementation that handles the protocol without MySql.Data dependency. + /// + internal class DirectMySQLConnection : IDisposable + { + private TcpClient _tcpClient; + private NetworkStream _stream; + private byte _sequenceId; + private bool _isConnected; + private string _serverVersion; + private uint _connectionId; + + public bool IsConnected => _isConnected; + public string ServerVersion => _serverVersion; + public uint ConnectionId => _connectionId; + public Stream Stream => _stream; + + /// + /// Connects to the MySQL server and performs handshake authentication. + /// + public async Task ConnectAsync(string host, int port, string username, string password, string database = null) + { + if (_isConnected) + throw new InvalidOperationException("Connection is already established"); + + try + { + // Parse host and port + var parts = host.Split(':'); + var hostName = parts[0]; + var portNumber = parts.Length > 1 ? int.Parse(parts[1]) : port; + + // Establish TCP connection + _tcpClient = new TcpClient(); + await _tcpClient.ConnectAsync(hostName, portNumber).ConfigureAwait(false); + _stream = _tcpClient.GetStream(); + + // Perform handshake + await PerformHandshakeAsync(username, password, database).ConfigureAwait(false); + + _isConnected = true; + } + catch + { + Dispose(); + throw; + } + } + + /// + /// Performs the MySQL handshake protocol. + /// + private async Task PerformHandshakeAsync(string username, string password, string database) + { + // Read initial handshake packet from server + var (handshakePayload, sequenceId) = await MySQLPacket.ReadFromStreamAsync(_stream).ConfigureAwait(false); + _sequenceId = (byte)(sequenceId + 1); + + // Check for error packet + if (handshakePayload[0] == 0xFF) + { + throw new Exception($"Server error during handshake: {ParseErrorPacket(handshakePayload)}"); + } + + // Parse handshake packet + var handshake = HandshakePacket.Parse(handshakePayload); + _serverVersion = handshake.ServerVersion; + _connectionId = handshake.ConnectionId; + + // Generate auth response + var authResponse = MySQLAuth.MySqlNativePassword(password, handshake.AuthPluginData); + + // Create handshake response + var handshakeResponse = new HandshakeResponsePacket(username, authResponse, database, MySQLAuth.MySqlNativePasswordPlugin) + { + SequenceId = _sequenceId + }; + + // Send handshake response + await handshakeResponse.WriteToStreamAsync(_stream).ConfigureAwait(false); + _sequenceId++; + + // Read server response + var (responsePayload, responseSequenceId) = await MySQLPacket.ReadFromStreamAsync(_stream).ConfigureAwait(false); + _sequenceId = (byte)(responseSequenceId + 1); + + // Check response + if (responsePayload[0] == 0xFF) + { + throw new Exception($"Authentication failed: {ParseErrorPacket(responsePayload)}"); + } + else if (responsePayload[0] == 0x00) + { + // Success - OK packet + return; + } + else + { + throw new Exception("Unexpected response during authentication"); + } + } + + /// + /// Executes a SQL query and returns the result set. + /// + public async Task ExecuteQueryAsync(string query) + { + if (!_isConnected) + throw new InvalidOperationException("Connection is not established"); + + // Reset sequence ID for new command + _sequenceId = 0; + + // Send command packet + var commandPacket = new CommandPacket(MySQLCommand.COM_QUERY, query) + { + SequenceId = _sequenceId + }; + + await commandPacket.WriteToStreamAsync(_stream).ConfigureAwait(false); + _sequenceId++; + + // Read response packets + var (responsePayload, responseSequenceId) = await MySQLPacket.ReadFromStreamAsync(_stream).ConfigureAwait(false); + _sequenceId = (byte)(responseSequenceId + 1); + + // Check for error + if (responsePayload[0] == 0xFF) + { + throw new Exception($"Query failed: {ParseErrorPacket(responsePayload)}"); + } + + // Check for OK packet (for non-SELECT queries) + if (responsePayload[0] == 0x00) + { + return new MySQLQueryResult { IsSuccess = true }; + } + + // Parse result set + return await ParseResultSetAsync(responsePayload).ConfigureAwait(false); + } + + /// + /// Parses a result set from query response. + /// + private async Task ParseResultSetAsync(byte[] firstPacket) + { + var result = new MySQLQueryResult { IsSuccess = true }; + + // First packet contains column count + var columnCount = firstPacket[0]; + result.ColumnCount = columnCount; + + // Read column definition packets + var columns = new List(); + for (int i = 0; i < columnCount; i++) + { + var (columnPayload, sequenceId) = await MySQLPacket.ReadFromStreamAsync(_stream).ConfigureAwait(false); + _sequenceId = (byte)(sequenceId + 1); + + var column = ParseColumnDefinition(columnPayload); + columns.Add(column); + } + result.Columns = columns; + + // Read EOF packet after column definitions + var (eofPayload, eofSequenceId) = await MySQLPacket.ReadFromStreamAsync(_stream).ConfigureAwait(false); + _sequenceId = (byte)(eofSequenceId + 1); + + // Read data rows + var rows = new List(); + while (true) + { + var (rowPayload, rowSequenceId) = await MySQLPacket.ReadFromStreamAsync(_stream).ConfigureAwait(false); + _sequenceId = (byte)(rowSequenceId + 1); + + // Check for EOF packet + if (rowPayload[0] == 0xFE && rowPayload.Length < 9) + { + break; + } + + var row = ParseRow(rowPayload, columnCount); + rows.Add(row); + } + result.Rows = rows; + + return result; + } + + /// + /// Parses a column definition packet. + /// + private MySQLColumn ParseColumnDefinition(byte[] payload) + { + var column = new MySQLColumn(); + int offset = 0; + + // Skip catalog (length-encoded string) + offset += ReadLengthEncodedString(payload, offset, out _); + + // Skip schema (length-encoded string) + offset += ReadLengthEncodedString(payload, offset, out _); + + // Skip table (length-encoded string) + offset += ReadLengthEncodedString(payload, offset, out _); + + // Skip org_table (length-encoded string) + offset += ReadLengthEncodedString(payload, offset, out _); + + // Column name (length-encoded string) + string columnName; + offset += ReadLengthEncodedString(payload, offset, out columnName); + column.Name = columnName; + + // Skip org_name (length-encoded string) + offset += ReadLengthEncodedString(payload, offset, out _); + + // Skip length of fixed-length fields + offset++; + + // Character set (2 bytes) + offset += 2; + + // Column length (4 bytes) + offset += 4; + + // Column type (1 byte) + column.Type = payload[offset]; + + return column; + } + + /// + /// Parses a data row packet. + /// + private MySQLRow ParseRow(byte[] payload, int columnCount) + { + var row = new MySQLRow(); + var values = new string[columnCount]; + int offset = 0; + + for (int i = 0; i < columnCount; i++) + { + offset += ReadLengthEncodedString(payload, offset, out values[i]); + } + + row.Values = values; + return row; + } + + /// + /// Reads a length-encoded string from the payload. + /// + private int ReadLengthEncodedString(byte[] payload, int offset, out string value) + { + if (payload[offset] == 0xFB) + { + // NULL value + value = null; + return 1; + } + + var length = payload[offset]; + if (length < 0xFB) + { + // Single byte length + value = Encoding.UTF8.GetString(payload, offset + 1, length); + return 1 + length; + } + + // Multi-byte length (not implemented for simplicity) + throw new NotImplementedException("Multi-byte length-encoded strings not implemented"); + } + + /// + /// Parses an error packet. + /// + private string ParseErrorPacket(byte[] payload) + { + if (payload.Length < 3) + return "Unknown error"; + + var errorCode = BitConverter.ToUInt16(payload, 1); + var message = Encoding.UTF8.GetString(payload, 3, payload.Length - 3); + return $"Error {errorCode}: {message}"; + } + + public void Dispose() + { + _isConnected = false; + _stream?.Dispose(); + _tcpClient?.Dispose(); + } + } + + /// + /// Represents the result of a MySQL query. + /// + internal class MySQLQueryResult + { + public bool IsSuccess { get; set; } + public int ColumnCount { get; set; } + public List Columns { get; set; } + public List Rows { get; set; } + } + + /// + /// Represents a MySQL column definition. + /// + internal class MySQLColumn + { + public string Name { get; set; } + public byte Type { get; set; } + } + + /// + /// Represents a MySQL data row. + /// + internal class MySQLRow + { + public string[] Values { get; set; } + } +} \ No newline at end of file diff --git a/src/SciSharp.MySQL.Replication/Protocol/HandshakePacket.cs b/src/SciSharp.MySQL.Replication/Protocol/HandshakePacket.cs new file mode 100644 index 0000000..f541d38 --- /dev/null +++ b/src/SciSharp.MySQL.Replication/Protocol/HandshakePacket.cs @@ -0,0 +1,152 @@ +using System; +using System.Text; + +namespace SciSharp.MySQL.Replication.Protocol +{ + /// + /// Represents the initial handshake packet sent by the MySQL server. + /// Reference: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake + /// + internal class HandshakePacket + { + public byte ProtocolVersion { get; set; } + public string ServerVersion { get; set; } + public uint ConnectionId { get; set; } + public byte[] AuthPluginDataPart1 { get; set; } // 8 bytes + public uint CapabilityFlagsLower { get; set; } // lower 16 bits + public byte CharacterSet { get; set; } + public ushort StatusFlags { get; set; } + public uint CapabilityFlagsUpper { get; set; } // upper 16 bits + public byte AuthPluginDataLength { get; set; } + public byte[] AuthPluginDataPart2 { get; set; } // max 13 bytes (total auth_plugin_data is 21 bytes) + public string AuthPluginName { get; set; } + + /// + /// Gets the combined capability flags (lower 16 bits + upper 16 bits). + /// + public uint CapabilityFlags => CapabilityFlagsLower | (CapabilityFlagsUpper << 16); + + /// + /// Gets the complete authentication plugin data (part1 + part2). + /// + public byte[] AuthPluginData + { + get + { + var result = new byte[20]; // MySQL uses 20 bytes for auth data + Array.Copy(AuthPluginDataPart1, 0, result, 0, 8); + if (AuthPluginDataPart2 != null) + { + var copyLength = Math.Min(AuthPluginDataPart2.Length, 12); + Array.Copy(AuthPluginDataPart2, 0, result, 8, copyLength); + } + return result; + } + } + + /// + /// Parses a handshake packet from raw bytes. + /// + public static HandshakePacket Parse(byte[] payload) + { + var packet = new HandshakePacket(); + int offset = 0; + + // Protocol version (1 byte) + packet.ProtocolVersion = payload[offset++]; + + // Server version (null-terminated string) + var serverVersionEnd = Array.IndexOf(payload, (byte)0, offset); + packet.ServerVersion = Encoding.UTF8.GetString(payload, offset, serverVersionEnd - offset); + offset = serverVersionEnd + 1; + + // Connection ID (4 bytes) + packet.ConnectionId = BitConverter.ToUInt32(payload, offset); + offset += 4; + + // Auth plugin data part 1 (8 bytes) + packet.AuthPluginDataPart1 = new byte[8]; + Array.Copy(payload, offset, packet.AuthPluginDataPart1, 0, 8); + offset += 8; + + // Filter (1 byte) - always 0x00 + offset++; + + // Capability flags lower 16 bits (2 bytes) + packet.CapabilityFlagsLower = BitConverter.ToUInt16(payload, offset); + offset += 2; + + // Character set (1 byte) + packet.CharacterSet = payload[offset++]; + + // Status flags (2 bytes) + packet.StatusFlags = BitConverter.ToUInt16(payload, offset); + offset += 2; + + // Capability flags upper 16 bits (2 bytes) + packet.CapabilityFlagsUpper = BitConverter.ToUInt16(payload, offset); + offset += 2; + + // Auth plugin data length (1 byte) + packet.AuthPluginDataLength = payload[offset++]; + + // Reserved (10 bytes) - skip + offset += 10; + + // Auth plugin data part 2 (max 13 bytes, but actual length is auth_plugin_data_len - 8) + if (packet.AuthPluginDataLength > 8) + { + var part2Length = Math.Min(packet.AuthPluginDataLength - 8, 13); + packet.AuthPluginDataPart2 = new byte[part2Length]; + Array.Copy(payload, offset, packet.AuthPluginDataPart2, 0, part2Length); + offset += part2Length; + } + + // Auth plugin name (null-terminated string) - if CLIENT_PLUGIN_AUTH capability is set + if ((packet.CapabilityFlags & (uint)ClientCapabilities.CLIENT_PLUGIN_AUTH) != 0) + { + var authPluginNameEnd = Array.IndexOf(payload, (byte)0, offset); + if (authPluginNameEnd > offset) + { + packet.AuthPluginName = Encoding.UTF8.GetString(payload, offset, authPluginNameEnd - offset); + } + } + + return packet; + } + } + + /// + /// Client capability flags used in the handshake. + /// Reference: https://dev.mysql.com/doc/internals/en/capability-flags.html + /// + [Flags] + internal enum ClientCapabilities : uint + { + CLIENT_LONG_PASSWORD = 0x00000001, + CLIENT_FOUND_ROWS = 0x00000002, + CLIENT_LONG_FLAG = 0x00000004, + CLIENT_CONNECT_WITH_DB = 0x00000008, + CLIENT_NO_SCHEMA = 0x00000010, + CLIENT_COMPRESS = 0x00000020, + CLIENT_ODBC = 0x00000040, + CLIENT_LOCAL_FILES = 0x00000080, + CLIENT_IGNORE_SPACE = 0x00000100, + CLIENT_PROTOCOL_41 = 0x00000200, + CLIENT_INTERACTIVE = 0x00000400, + CLIENT_SSL = 0x00000800, + CLIENT_IGNORE_SIGPIPE = 0x00001000, + CLIENT_TRANSACTIONS = 0x00002000, + CLIENT_RESERVED = 0x00004000, + CLIENT_SECURE_CONNECTION = 0x00008000, + CLIENT_MULTI_STATEMENTS = 0x00010000, + CLIENT_MULTI_RESULTS = 0x00020000, + CLIENT_PS_MULTI_RESULTS = 0x00040000, + CLIENT_PLUGIN_AUTH = 0x00080000, + CLIENT_CONNECT_ATTRS = 0x00100000, + CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 0x00200000, + CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS = 0x00400000, + CLIENT_SESSION_TRACK = 0x00800000, + CLIENT_DEPRECATE_EOF = 0x01000000 + } +} \ No newline at end of file diff --git a/src/SciSharp.MySQL.Replication/Protocol/HandshakeResponsePacket.cs b/src/SciSharp.MySQL.Replication/Protocol/HandshakeResponsePacket.cs new file mode 100644 index 0000000..370a050 --- /dev/null +++ b/src/SciSharp.MySQL.Replication/Protocol/HandshakeResponsePacket.cs @@ -0,0 +1,104 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace SciSharp.MySQL.Replication.Protocol +{ + /// + /// Represents the handshake response packet sent by the client. + /// Reference: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse + /// + internal class HandshakeResponsePacket : MySQLPacket + { + public uint CapabilityFlags { get; set; } + public uint MaxPacketSize { get; set; } + public byte CharacterSet { get; set; } + public string Username { get; set; } + public byte[] AuthResponse { get; set; } + public string Database { get; set; } + public string AuthPluginName { get; set; } + + public HandshakeResponsePacket(string username, byte[] authResponse, string database = null, string authPluginName = "mysql_native_password") + { + Username = username ?? throw new ArgumentNullException(nameof(username)); + AuthResponse = authResponse ?? throw new ArgumentNullException(nameof(authResponse)); + Database = database; + AuthPluginName = authPluginName; + + // Set default capability flags for replication client + CapabilityFlags = (uint)( + ClientCapabilities.CLIENT_PROTOCOL_41 | + ClientCapabilities.CLIENT_SECURE_CONNECTION | + ClientCapabilities.CLIENT_LONG_PASSWORD | + ClientCapabilities.CLIENT_TRANSACTIONS | + ClientCapabilities.CLIENT_PLUGIN_AUTH | + ClientCapabilities.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA + ); + + if (!string.IsNullOrEmpty(database)) + { + CapabilityFlags |= (uint)ClientCapabilities.CLIENT_CONNECT_WITH_DB; + } + + MaxPacketSize = 0x01000000; // 16MB + CharacterSet = 8; // latin1_swedish_ci + } + + protected override byte[] GetPayload() + { + var payload = new List(); + + // Capability flags (4 bytes) + payload.AddRange(BitConverter.GetBytes(CapabilityFlags)); + + // Max packet size (4 bytes) + payload.AddRange(BitConverter.GetBytes(MaxPacketSize)); + + // Character set (1 byte) + payload.Add(CharacterSet); + + // Reserved (23 bytes of zeros) + payload.AddRange(new byte[23]); + + // Username (null-terminated string) + payload.AddRange(Encoding.UTF8.GetBytes(Username)); + payload.Add(0); + + // Auth response + if ((CapabilityFlags & (uint)ClientCapabilities.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) != 0) + { + // Length-encoded auth response + payload.Add((byte)AuthResponse.Length); + payload.AddRange(AuthResponse); + } + else if ((CapabilityFlags & (uint)ClientCapabilities.CLIENT_SECURE_CONNECTION) != 0) + { + // Length-prefixed auth response + payload.Add((byte)AuthResponse.Length); + payload.AddRange(AuthResponse); + } + else + { + // Null-terminated auth response + payload.AddRange(AuthResponse); + payload.Add(0); + } + + // Database (null-terminated string) - if CLIENT_CONNECT_WITH_DB is set + if ((CapabilityFlags & (uint)ClientCapabilities.CLIENT_CONNECT_WITH_DB) != 0 && !string.IsNullOrEmpty(Database)) + { + payload.AddRange(Encoding.UTF8.GetBytes(Database)); + payload.Add(0); + } + + // Auth plugin name (null-terminated string) - if CLIENT_PLUGIN_AUTH is set + if ((CapabilityFlags & (uint)ClientCapabilities.CLIENT_PLUGIN_AUTH) != 0 && !string.IsNullOrEmpty(AuthPluginName)) + { + payload.AddRange(Encoding.UTF8.GetBytes(AuthPluginName)); + payload.Add(0); + } + + return payload.ToArray(); + } + } +} \ No newline at end of file diff --git a/src/SciSharp.MySQL.Replication/Protocol/MySQLAuth.cs b/src/SciSharp.MySQL.Replication/Protocol/MySQLAuth.cs new file mode 100644 index 0000000..4847959 --- /dev/null +++ b/src/SciSharp.MySQL.Replication/Protocol/MySQLAuth.cs @@ -0,0 +1,58 @@ +using System; +using System.Security.Cryptography; +using System.Text; + +namespace SciSharp.MySQL.Replication.Protocol +{ + /// + /// Implements MySQL authentication mechanisms. + /// + internal static class MySQLAuth + { + /// + /// Implements the mysql_native_password authentication method. + /// Formula: SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password))) + /// Reference: https://dev.mysql.com/doc/internals/en/secure-password-authentication.html + /// + public static byte[] MySqlNativePassword(string password, byte[] scramble) + { + if (string.IsNullOrEmpty(password)) + return new byte[0]; + + if (scramble == null || scramble.Length < 20) + throw new ArgumentException("Scramble must be at least 20 bytes", nameof(scramble)); + + using (var sha1 = SHA1.Create()) + { + // Step 1: SHA1(password) + var passwordBytes = Encoding.UTF8.GetBytes(password); + var sha1Password = sha1.ComputeHash(passwordBytes); + + // Step 2: SHA1(SHA1(password)) + var sha1Sha1Password = sha1.ComputeHash(sha1Password); + + // Step 3: scramble + SHA1(SHA1(password)) + var scrambleAndHash = new byte[20 + sha1Sha1Password.Length]; + Array.Copy(scramble, 0, scrambleAndHash, 0, 20); + Array.Copy(sha1Sha1Password, 0, scrambleAndHash, 20, sha1Sha1Password.Length); + + // Step 4: SHA1(scramble + SHA1(SHA1(password))) + var sha1ScrambleAndHash = sha1.ComputeHash(scrambleAndHash); + + // Step 5: SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password))) + var result = new byte[20]; + for (int i = 0; i < 20; i++) + { + result[i] = (byte)(sha1Password[i] ^ sha1ScrambleAndHash[i]); + } + + return result; + } + } + + /// + /// Gets the auth plugin name for mysql_native_password. + /// + public const string MySqlNativePasswordPlugin = "mysql_native_password"; + } +} \ No newline at end of file diff --git a/src/SciSharp.MySQL.Replication/Protocol/MySQLPacket.cs b/src/SciSharp.MySQL.Replication/Protocol/MySQLPacket.cs new file mode 100644 index 0000000..fba72b1 --- /dev/null +++ b/src/SciSharp.MySQL.Replication/Protocol/MySQLPacket.cs @@ -0,0 +1,75 @@ +using System; +using System.Buffers.Binary; +using System.IO; +using System.Threading.Tasks; + +namespace SciSharp.MySQL.Replication.Protocol +{ + /// + /// Base class for MySQL protocol packets. + /// MySQL protocol packet format: [3 bytes length][1 byte sequence_id][payload] + /// + internal abstract class MySQLPacket + { + public byte SequenceId { get; set; } + + /// + /// Writes the packet to a stream. + /// + public async Task WriteToStreamAsync(Stream stream) + { + var payload = GetPayload(); + var length = payload.Length; + + // Write packet header (3 bytes length + 1 byte sequence_id) + var header = new byte[4]; + header[0] = (byte)(length & 0xFF); + header[1] = (byte)((length >> 8) & 0xFF); + header[2] = (byte)((length >> 16) & 0xFF); + header[3] = SequenceId; + + await stream.WriteAsync(header).ConfigureAwait(false); + await stream.WriteAsync(payload).ConfigureAwait(false); + await stream.FlushAsync().ConfigureAwait(false); + } + + /// + /// Reads a packet from a stream. + /// + public static async Task<(byte[] payload, byte sequenceId)> ReadFromStreamAsync(Stream stream) + { + // Read packet header + var header = new byte[4]; + await ReadExactlyAsync(stream, header, 4).ConfigureAwait(false); + + var length = header[0] | (header[1] << 8) | (header[2] << 16); + var sequenceId = header[3]; + + // Read payload + var payload = new byte[length]; + await ReadExactlyAsync(stream, payload, length).ConfigureAwait(false); + + return (payload, sequenceId); + } + + /// + /// Helper method to read exactly the specified number of bytes. + /// + private static async Task ReadExactlyAsync(Stream stream, byte[] buffer, int count) + { + int totalBytesRead = 0; + while (totalBytesRead < count) + { + int bytesRead = await stream.ReadAsync(buffer, totalBytesRead, count - totalBytesRead).ConfigureAwait(false); + if (bytesRead == 0) + throw new EndOfStreamException("Unexpected end of stream while reading MySQL packet"); + totalBytesRead += bytesRead; + } + } + + /// + /// Gets the payload bytes for this packet. + /// + protected abstract byte[] GetPayload(); + } +} \ No newline at end of file diff --git a/src/SciSharp.MySQL.Replication/ReplicationClient.cs b/src/SciSharp.MySQL.Replication/ReplicationClient.cs index ae754a2..288442d 100644 --- a/src/SciSharp.MySQL.Replication/ReplicationClient.cs +++ b/src/SciSharp.MySQL.Replication/ReplicationClient.cs @@ -7,8 +7,8 @@ using System.Reflection; using System.Threading.Tasks; using Microsoft.Extensions.Logging; -using MySql.Data.MySqlClient; using SciSharp.MySQL.Replication.Events; +using SciSharp.MySQL.Replication.Protocol; using SuperSocket.Client; using SuperSocket.Connection; @@ -28,7 +28,7 @@ public class ReplicationClient : EasyClient, IReplicationClient, IAsyn private const int BINLOG_SEND_ANNOTATE_ROWS_EVENT = 2; - private MySqlConnection _connection; + private DirectMySQLConnection _connection; private int _serverId; @@ -96,20 +96,7 @@ private ReplicationClient(LogEventPipelineFilter logEventPipelineFilter) _tableSchemaMap = (logEventPipelineFilter.Context as ReplicationState).TableSchemaMap; } - /// - /// Gets the underlying stream from a MySQL connection. - /// - /// The MySQL connection. - /// The stream associated with the connection. - private Stream GetStreamFromMySQLConnection(MySqlConnection connection) - { - var driverField = connection.GetType().GetField("driver", BindingFlags.Instance | BindingFlags.NonPublic); - var driver = driverField.GetValue(connection); - var handlerField = driver.GetType().GetField("handler", BindingFlags.Instance | BindingFlags.NonPublic); - var handler = handlerField.GetValue(driver); - var baseStreamField = handler.GetType().GetField("baseStream", BindingFlags.Instance | BindingFlags.NonPublic); - return baseStreamField.GetValue(handler) as Stream; - } + /// /// Connects to a MySQL server as a replication client. @@ -152,12 +139,11 @@ public async Task ConnectAsync(string server, string username, stri /// A task representing the asynchronous operation, with a result indicating whether the login was successful. private async Task ConnectInternalAsync(string server, string username, string password, int serverId, BinlogPosition binlogPosition) { - var connString = $"Server={server}; UID={username}; Password={password}"; - var mysqlConn = new MySqlConnection(connString); + var directConn = new DirectMySQLConnection(); try { - await mysqlConn.OpenAsync().ConfigureAwait(false); + await directConn.ConnectAsync(server, 3306, username, password).ConfigureAwait(false); } catch (Exception e) { @@ -171,27 +157,27 @@ private async Task ConnectInternalAsync(string server, string usern try { // Load database schema using the established connection - await LoadDatabaseSchemaAsync(mysqlConn).ConfigureAwait(false); + await LoadDatabaseSchemaAsync(directConn).ConfigureAwait(false); // If no binlog position was provided, get the current position from the server if (binlogPosition == null) { - binlogPosition = await GetBinlogFileNameAndPosition(mysqlConn).ConfigureAwait(false); + binlogPosition = await GetBinlogFileNameAndPosition(directConn).ConfigureAwait(false); } // Set up checksum verification - var binlogChecksum = await GetBinlogChecksum(mysqlConn).ConfigureAwait(false); - await ConfirmChecksum(mysqlConn).ConfigureAwait(false); + var binlogChecksum = await GetBinlogChecksum(directConn).ConfigureAwait(false); + await ConfirmChecksum(directConn).ConfigureAwait(false); LogEvent.ChecksumType = binlogChecksum; // Get the underlying stream and start the binlog dump - _stream = GetStreamFromMySQLConnection(mysqlConn); + _stream = directConn.Stream; _serverId = serverId; _currentPosition = new BinlogPosition(binlogPosition); await StartDumpBinlog(_stream, serverId, binlogPosition.Filename, binlogPosition.Position).ConfigureAwait(false); - _connection = mysqlConn; + _connection = directConn; // Create a connection for the event stream var connection = new StreamPipeConnection( @@ -210,7 +196,7 @@ private async Task ConnectInternalAsync(string server, string usern } catch (Exception e) { - await mysqlConn.CloseAsync().ConfigureAwait(false); + directConn.Dispose(); return new LoginResult { @@ -267,30 +253,32 @@ private void TrackBinlogPosition(LogEvent logEvent) } } - private async Task LoadDatabaseSchemaAsync(MySqlConnection mysqlConn) + private async Task LoadDatabaseSchemaAsync(DirectMySQLConnection directConn) { - var tableSchemaTable = await mysqlConn.GetSchemaAsync("Columns").ConfigureAwait(false); - - var systemDatabases = new HashSet( - new [] { "mysql", "information_schema", "performance_schema", "sys" }, - StringComparer.OrdinalIgnoreCase); - - var userDatabaseColumns = tableSchemaTable.Rows.OfType() - .Where(row => !systemDatabases.Contains(row.ItemArray[1].ToString())) - .ToArray(); + var query = @"SELECT + TABLE_SCHEMA, + TABLE_NAME, + COLUMN_NAME, + DATA_TYPE, + CHARACTER_MAXIMUM_LENGTH + FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA NOT IN ('mysql', 'information_schema', 'performance_schema', 'sys')"; + + var result = await directConn.ExecuteQueryAsync(query).ConfigureAwait(false); + + if (!result.IsSuccess || result.Rows == null) + return; - userDatabaseColumns.Select(row => - { - var columnSizeCell = row["CHARACTER_MAXIMUM_LENGTH"]; - - return new { - TableName = row["TABLE_NAME"].ToString(), - DatabaseName = row["TABLE_SCHEMA"].ToString(), - ColumnName = row["COLUMN_NAME"].ToString(), - ColumnType = row["DATA_TYPE"].ToString(), - ColumnSize = columnSizeCell == DBNull.Value ? 0 : Convert.ToUInt64(columnSizeCell), - }; - }) + var userDatabaseColumns = result.Rows.Select(row => new + { + TableName = row.Values[1], + DatabaseName = row.Values[0], + ColumnName = row.Values[2], + ColumnType = row.Values[3], + ColumnSize = string.IsNullOrEmpty(row.Values[4]) ? 0UL : Convert.ToUInt64(row.Values[4]) + }).ToArray(); + + userDatabaseColumns .GroupBy(row => new { row.TableName, row.DatabaseName }) .ToList() .ForEach(group => @@ -315,59 +303,46 @@ private async Task LoadDatabaseSchemaAsync(MySqlConnection mysqlConn) /// Retrieves the binary log file name and position from the MySQL server. /// https://dev.mysql.com/doc/refman/5.6/en/replication-howto-masterstatus.html /// - /// The MySQL connection. + /// The direct MySQL connection. /// A tuple containing the binary log file name and position. - private async Task GetBinlogFileNameAndPosition(MySqlConnection mysqlConn) + private async Task GetBinlogFileNameAndPosition(DirectMySQLConnection directConn) { - var cmd = mysqlConn.CreateCommand(); - cmd.CommandText = "SHOW MASTER STATUS;"; + var result = await directConn.ExecuteQueryAsync("SHOW MASTER STATUS").ConfigureAwait(false); - using (var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false)) - { - if (!await reader.ReadAsync()) - throw new Exception("No binlog information has been returned."); - - var fileName = reader.GetString(0); - var position = reader.GetInt32(1); + if (!result.IsSuccess || result.Rows == null || result.Rows.Count == 0) + throw new Exception("No binlog information has been returned."); - await reader.CloseAsync().ConfigureAwait(false); + var firstRow = result.Rows[0]; + var fileName = firstRow.Values[0]; + var position = int.Parse(firstRow.Values[1]); - return new BinlogPosition(fileName, position); - } + return new BinlogPosition(fileName, position); } /// /// Retrieves the binary log checksum type from the MySQL server. /// - /// The MySQL connection. + /// The direct MySQL connection. /// The checksum type. - private async Task GetBinlogChecksum(MySqlConnection mysqlConn) + private async Task GetBinlogChecksum(DirectMySQLConnection directConn) { - var cmd = mysqlConn.CreateCommand(); - cmd.CommandText = "show global variables like 'binlog_checksum';"; + var result = await directConn.ExecuteQueryAsync("show global variables like 'binlog_checksum'").ConfigureAwait(false); - using (var reader = await cmd.ExecuteReaderAsync()) - { - if (!await reader.ReadAsync().ConfigureAwait(false)) - return ChecksumType.NONE; + if (!result.IsSuccess || result.Rows == null || result.Rows.Count == 0) + return ChecksumType.NONE; - var checksumTypeName = reader.GetString(1).ToUpper(); - await reader.CloseAsync().ConfigureAwait(false); - - return (ChecksumType)Enum.Parse(typeof(ChecksumType), checksumTypeName); - } + var checksumTypeName = result.Rows[0].Values[1].ToUpper(); + return (ChecksumType)Enum.Parse(typeof(ChecksumType), checksumTypeName); } /// /// Confirms the binary log checksum setting on the MySQL server. /// - /// The MySQL connection. + /// The direct MySQL connection. /// A task representing the asynchronous operation. - private async ValueTask ConfirmChecksum(MySqlConnection mysqlConn) + private async ValueTask ConfirmChecksum(DirectMySQLConnection directConn) { - var cmd = mysqlConn.CreateCommand(); - cmd.CommandText = "set @`master_binlog_checksum` = @@binlog_checksum;"; - await cmd.ExecuteNonQueryAsync().ConfigureAwait(false); + await directConn.ExecuteQueryAsync("set @`master_binlog_checksum` = @@binlog_checksum").ConfigureAwait(false); } /// @@ -490,7 +465,7 @@ public override async ValueTask CloseAsync() if (connection != null) { _connection = null; - await connection.CloseAsync().ConfigureAwait(false); + connection.Dispose(); } await base.CloseAsync().ConfigureAwait(false); diff --git a/src/SciSharp.MySQL.Replication/SciSharp.MySQL.Replication.csproj b/src/SciSharp.MySQL.Replication/SciSharp.MySQL.Replication.csproj index a014509..b61b2fd 100644 --- a/src/SciSharp.MySQL.Replication/SciSharp.MySQL.Replication.csproj +++ b/src/SciSharp.MySQL.Replication/SciSharp.MySQL.Replication.csproj @@ -1,6 +1,6 @@ - net6.0;net7.0;net8.0;net9.0 + net6.0;net7.0;net8.0 12.0 true @@ -24,7 +24,6 @@ - diff --git a/tests/Test/DirectConnectionTest.cs b/tests/Test/DirectConnectionTest.cs new file mode 100644 index 0000000..58857c9 --- /dev/null +++ b/tests/Test/DirectConnectionTest.cs @@ -0,0 +1,51 @@ +using System; +using System.Threading.Tasks; +using SciSharp.MySQL.Replication; +using Xunit; +using Xunit.Abstractions; + +namespace Test +{ + [Trait("Category", "DirectConnection")] + public class DirectConnectionTest + { + protected readonly ITestOutputHelper _outputHelper; + + public DirectConnectionTest(ITestOutputHelper outputHelper) + { + _outputHelper = outputHelper; + } + + [Fact] + public async Task TestReplicationClientWithDirectConnection() + { + var client = new ReplicationClient(); + + try + { + // Test connection using the new implementation + var result = await client.ConnectAsync("localhost", "root", "root", 1001); + + Assert.True(result.Result, $"Connection failed: {result.Message}"); + _outputHelper.WriteLine($"ReplicationClient connected successfully using direct MySQL protocol"); + _outputHelper.WriteLine($"Connection result message: {result.Message ?? "Success"}"); + + // Verify current position is available + Assert.NotNull(client.CurrentPosition); + _outputHelper.WriteLine($"Current binlog position: {client.CurrentPosition.Filename}:{client.CurrentPosition.Position}"); + + _outputHelper.WriteLine("Direct MySQL protocol implementation is working correctly!"); + } + catch (Exception ex) + { + _outputHelper.WriteLine($"Test failed with exception: {ex.Message}"); + _outputHelper.WriteLine($"Stack trace: {ex.StackTrace}"); + throw; + } + finally + { + await client.CloseAsync(); + } + } + } +} diff --git a/tests/Test/Test.csproj b/tests/Test/Test.csproj index cfc4ee1..aff0f71 100644 --- a/tests/Test/Test.csproj +++ b/tests/Test/Test.csproj @@ -1,7 +1,7 @@ - net9.0 + net8.0 false @@ -11,6 +11,7 @@ +