diff --git a/spec/sni_spec.cr b/spec/sni_spec.cr index 09ef868456..8f5b13b8ca 100644 --- a/spec/sni_spec.cr +++ b/spec/sni_spec.cr @@ -313,4 +313,82 @@ describe "SNI end-to-end" do tcp_server.close server_done.receive end + + it "copies verify_mode from the SNI selected context" do + # Setup SNI manager with mTLS enabled host + sni_manager = LavinMQ::SNIManager.new + mtls_host = LavinMQ::SNIHost.new("mtls.localhost") + mtls_host.tls_cert = "spec/resources/server_certificate.pem" + mtls_host.tls_key = "spec/resources/server_key.pem" + mtls_host.tls_verify_peer = true + mtls_host.tls_ca_cert = "spec/resources/ca_certificate.pem" + sni_manager.add_host(mtls_host) + + mtls_host.amqp_tls_context.verify_mode.should eq(OpenSSL::SSL::VerifyMode::PEER | OpenSSL::SSL::VerifyMode::FAIL_IF_NO_PEER_CERT) + + # Default server context without mTLS + default_ctx = OpenSSL::SSL::Context::Server.new + default_ctx.verify_mode = OpenSSL::SSL::VerifyMode::NONE + default_ctx.certificate_chain = "spec/resources/server_certificate.pem" + default_ctx.private_key = "spec/resources/server_key.pem" + + # SNI callback to switch to mTLS context for mtls.localhost + default_ctx.set_sni_callback do |hostname| + sni_manager.get_host(hostname).try(&.amqp_tls_context) + end + + # Start TLS server + tcp_server = TCPServer.new("127.0.0.1", 0) + port = tcp_server.local_address.port + + server_done = Channel(Nil).new + + spawn do + 2.times do + if client = tcp_server.accept? + begin + ssl_socket = OpenSSL::SSL::Socket::Server.new(client, default_ctx) + ssl_socket.close + rescue + # Ignore handshake errors in server + ensure + client.close + end + end + end + server_done.send(nil) + end + + # Test 1: Connection without client cert should be rejected + tcp_client1 = TCPSocket.new("127.0.0.1", port) + client_ctx1 = OpenSSL::SSL::Context::Client.new + client_ctx1.verify_mode = OpenSSL::SSL::VerifyMode::NONE + begin + expect_raises(Exception) do + ssl_client1 = OpenSSL::SSL::Socket::Client.new(tcp_client1, client_ctx1, hostname: "mtls.localhost") + # If handshake succeeds, try to read which should fail + ssl_client1.gets + end + ensure + tcp_client1.close + end + + # Test 2: Connection with valid client cert should succeed + tcp_client2 = TCPSocket.new("127.0.0.1", port) + client_ctx2 = OpenSSL::SSL::Context::Client.new + client_ctx2.verify_mode = OpenSSL::SSL::VerifyMode::NONE + client_ctx2.certificate_chain = "spec/resources/client_certificate.pem" + client_ctx2.private_key = "spec/resources/client_key.pem" + begin + ssl_client2 = OpenSSL::SSL::Socket::Client.new(tcp_client2, client_ctx2, hostname: "mtls.localhost") + # If handshake succeeds, try to read too + ssl_client2.gets + ssl_client2.close + ensure + tcp_client2.close + end + + tcp_server.close + server_done.receive + end end diff --git a/src/stdlib/openssl_sni.cr b/src/stdlib/openssl_sni.cr index 8f0c1adb6b..9c234e44a0 100644 --- a/src/stdlib/openssl_sni.cr +++ b/src/stdlib/openssl_sni.cr @@ -18,6 +18,7 @@ lib LibSSL fun ssl_ctx_callback_ctrl = SSL_CTX_callback_ctrl(ctx : SSLContext, cmd : LibC::Int, fp : Proc(Void)) : LibC::Long fun ssl_set_ssl_ctx = SSL_set_SSL_CTX(ssl : SSL, ctx : SSLContext) : SSLContext + fun ssl_set_verify = SSL_set_verify(ssl : SSL, mode : LibC::Int, callback : Void*) : Void end class OpenSSL::SSL::Context::Server @@ -62,6 +63,8 @@ class OpenSSL::SSL::Context::Server if new_context # Switch to the new SSL_CTX for this connection LibSSL.ssl_set_ssl_ctx(ssl, new_context.to_unsafe) + verify_mode = LibSSL.ssl_ctx_get_verify_mode(new_context.to_unsafe).to_i + LibSSL.ssl_set_verify(ssl, verify_mode, nil) end LibSSL::SSL_TLSEXT_ERR_OK