diff --git a/extras/lavinmq.ini b/extras/lavinmq.ini index 3384a578a6..1016ddd620 100644 --- a/extras/lavinmq.ini +++ b/extras/lavinmq.ini @@ -15,6 +15,8 @@ bind = :: bind = :: ;port = 5672 ;tls_port = 5671 +;tcp_proxy_protocol = 0 +;proxy_protocol_trusted_sources = 10.0.0.1, 192.168.0.0/24, 2001:db8::/32 [mqtt] bind = :: diff --git a/spec/ip_matcher_spec.cr b/spec/ip_matcher_spec.cr new file mode 100644 index 0000000000..ca53193f7f --- /dev/null +++ b/spec/ip_matcher_spec.cr @@ -0,0 +1,277 @@ +require "./spec_helper" + +describe LavinMQ::IPMatcher do + describe ".parse and #matches?" do + describe "IPv4 exact IP matching" do + it "matches exact IPv4 address" do + matcher = LavinMQ::IPMatcher.parse("192.168.1.1") + matcher.matches?("192.168.1.1").should be_true + end + + it "doesn't match different IPv4 address" do + matcher = LavinMQ::IPMatcher.parse("192.168.1.1") + matcher.matches?("192.168.1.2").should be_false + matcher.matches?("192.168.2.1").should be_false + matcher.matches?("10.0.0.1").should be_false + end + + it "handles loopback address" do + matcher = LavinMQ::IPMatcher.parse("127.0.0.1") + matcher.matches?("127.0.0.1").should be_true + matcher.matches?("127.0.0.2").should be_false + end + end + + describe "IPv6 exact IP matching" do + it "matches exact IPv6 address" do + matcher = LavinMQ::IPMatcher.parse("2001:db8::1") + matcher.matches?("2001:db8::1").should be_true + end + + it "doesn't match different IPv6 address" do + matcher = LavinMQ::IPMatcher.parse("2001:db8::1") + matcher.matches?("2001:db8::2").should be_false + matcher.matches?("2001:db9::1").should be_false + end + + it "handles IPv6 loopback" do + matcher = LavinMQ::IPMatcher.parse("::1") + matcher.matches?("::1").should be_true + matcher.matches?("::2").should be_false + end + + it "exact match requires same string representation" do + # Exact IPs use string comparison for performance + # Different representations of the same IPv6 address won't match + matcher = LavinMQ::IPMatcher.parse("2001:0db8:0000:0000:0000:0000:0000:0001") + matcher.matches?("2001:0db8:0000:0000:0000:0000:0000:0001").should be_true + matcher.matches?("2001:db8::1").should be_false + + # If you need to match different representations, use CIDR /128 + cidr_matcher = LavinMQ::IPMatcher.parse("2001:db8::1/128") + cidr_matcher.matches?("2001:db8::1").should be_true + end + end + + describe "IPv4 CIDR matching" do + it "matches IP in /24 range" do + matcher = LavinMQ::IPMatcher.parse("192.168.1.0/24") + matcher.matches?("192.168.1.0").should be_true + matcher.matches?("192.168.1.1").should be_true + matcher.matches?("192.168.1.50").should be_true + matcher.matches?("192.168.1.255").should be_true + end + + it "doesn't match IP outside /24 range" do + matcher = LavinMQ::IPMatcher.parse("192.168.1.0/24") + matcher.matches?("192.168.0.255").should be_false + matcher.matches?("192.168.2.0").should be_false + matcher.matches?("192.169.1.1").should be_false + matcher.matches?("10.0.0.1").should be_false + end + + it "matches IP in /16 range" do + matcher = LavinMQ::IPMatcher.parse("192.168.0.0/16") + matcher.matches?("192.168.0.0").should be_true + matcher.matches?("192.168.1.1").should be_true + matcher.matches?("192.168.255.255").should be_true + end + + it "doesn't match IP outside /16 range" do + matcher = LavinMQ::IPMatcher.parse("192.168.0.0/16") + matcher.matches?("192.167.255.255").should be_false + matcher.matches?("192.169.0.0").should be_false + matcher.matches?("10.0.0.1").should be_false + end + + it "matches IP in /8 range" do + matcher = LavinMQ::IPMatcher.parse("10.0.0.0/8") + matcher.matches?("10.0.0.0").should be_true + matcher.matches?("10.1.2.3").should be_true + matcher.matches?("10.255.255.255").should be_true + end + + it "doesn't match IP outside /8 range" do + matcher = LavinMQ::IPMatcher.parse("10.0.0.0/8") + matcher.matches?("9.255.255.255").should be_false + matcher.matches?("11.0.0.0").should be_false + matcher.matches?("192.168.1.1").should be_false + end + + it "matches IP in /32 range (single host)" do + matcher = LavinMQ::IPMatcher.parse("192.168.1.100/32") + matcher.matches?("192.168.1.100").should be_true + matcher.matches?("192.168.1.99").should be_false + matcher.matches?("192.168.1.101").should be_false + end + + it "matches IP in /0 range (all IPs)" do + matcher = LavinMQ::IPMatcher.parse("0.0.0.0/0") + matcher.matches?("0.0.0.0").should be_true + matcher.matches?("192.168.1.1").should be_true + matcher.matches?("255.255.255.255").should be_true + end + + it "handles /25 CIDR" do + matcher = LavinMQ::IPMatcher.parse("192.168.1.0/25") + matcher.matches?("192.168.1.0").should be_true + matcher.matches?("192.168.1.127").should be_true + matcher.matches?("192.168.1.128").should be_false + matcher.matches?("192.168.1.255").should be_false + end + + it "handles /12 CIDR (AWS VPC default)" do + matcher = LavinMQ::IPMatcher.parse("172.16.0.0/12") + matcher.matches?("172.16.0.0").should be_true + matcher.matches?("172.31.255.255").should be_true + matcher.matches?("172.15.255.255").should be_false + matcher.matches?("172.32.0.0").should be_false + end + + it "normalizes network address" do + # Even if network address isn't properly masked, it should work + matcher = LavinMQ::IPMatcher.parse("192.168.1.50/24") + matcher.matches?("192.168.1.1").should be_true + matcher.matches?("192.168.1.255").should be_true + end + end + + describe "IPv6 CIDR matching" do + it "matches IP in /64 range" do + matcher = LavinMQ::IPMatcher.parse("2001:db8::/64") + matcher.matches?("2001:db8::1").should be_true + matcher.matches?("2001:db8::ffff").should be_true + matcher.matches?("2001:db8:0:0:ffff:ffff:ffff:ffff").should be_true + end + + it "doesn't match IP outside /64 range" do + matcher = LavinMQ::IPMatcher.parse("2001:db8::/64") + matcher.matches?("2001:db8:0:1::1").should be_false + matcher.matches?("2001:db9::1").should be_false + matcher.matches?("2001:db7:ffff:ffff:ffff:ffff:ffff:ffff").should be_false + end + + it "matches IP in /32 range" do + matcher = LavinMQ::IPMatcher.parse("2001:db8::/32") + matcher.matches?("2001:db8::1").should be_true + matcher.matches?("2001:db8:ffff:ffff:ffff:ffff:ffff:ffff").should be_true + end + + it "doesn't match IP outside /32 range" do + matcher = LavinMQ::IPMatcher.parse("2001:db8::/32") + matcher.matches?("2001:db7:ffff:ffff:ffff:ffff:ffff:ffff").should be_false + matcher.matches?("2001:db9::1").should be_false + end + + it "matches IP in /128 range (single host)" do + matcher = LavinMQ::IPMatcher.parse("2001:db8::1/128") + matcher.matches?("2001:db8::1").should be_true + matcher.matches?("2001:db8::2").should be_false + end + + it "matches IP in /0 range (all IPs)" do + matcher = LavinMQ::IPMatcher.parse("::/0") + matcher.matches?("::").should be_true + matcher.matches?("2001:db8::1").should be_true + matcher.matches?("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff").should be_true + end + + it "handles link-local /10 range" do + matcher = LavinMQ::IPMatcher.parse("fe80::/10") + matcher.matches?("fe80::1").should be_true + matcher.matches?("fe80::ffff:ffff:ffff:ffff").should be_true + matcher.matches?("febf:ffff:ffff:ffff:ffff:ffff:ffff:ffff").should be_true + matcher.matches?("fec0::1").should be_false + end + end + + describe "error handling" do + it "raises on invalid IPv4 address" do + expect_raises(ArgumentError, /Invalid IP address/) do + LavinMQ::IPMatcher.parse("256.0.0.1") + end + + expect_raises(ArgumentError, /Invalid IP address/) do + LavinMQ::IPMatcher.parse("not-an-ip") + end + end + + it "raises on invalid IPv6 address" do + expect_raises(ArgumentError, /Invalid IP address/) do + LavinMQ::IPMatcher.parse("gggg::1") + end + end + + it "raises on invalid CIDR notation" do + expect_raises(ArgumentError, /Invalid CIDR/) do + LavinMQ::IPMatcher.parse("192.168.1.0//24") + end + + expect_raises(ArgumentError, /Invalid CIDR/) do + LavinMQ::IPMatcher.parse("192.168.1.0/") + end + end + + it "raises on invalid CIDR prefix" do + expect_raises(ArgumentError, /Invalid CIDR prefix/) do + LavinMQ::IPMatcher.parse("192.168.1.0/abc") + end + end + + it "raises on IPv4 prefix > 32" do + expect_raises(ArgumentError, /IPv4 prefix must be 0-32/) do + LavinMQ::IPMatcher.parse("192.168.1.0/33") + end + + expect_raises(ArgumentError, /IPv4 prefix must be 0-32/) do + LavinMQ::IPMatcher.parse("10.0.0.0/255") + end + end + + it "raises on IPv6 prefix > 128" do + expect_raises(ArgumentError, /IPv6 prefix must be 0-128/) do + LavinMQ::IPMatcher.parse("2001:db8::/129") + end + + expect_raises(ArgumentError, /IPv6 prefix must be 0-128/) do + LavinMQ::IPMatcher.parse("::1/255") + end + end + + it "raises on CIDR with invalid IP" do + expect_raises(ArgumentError, /Invalid IP address in CIDR/) do + LavinMQ::IPMatcher.parse("not-an-ip/24") + end + end + end + + describe "edge cases" do + it "handles whitespace in IP" do + matcher = LavinMQ::IPMatcher.parse(" 192.168.1.1 ") + matcher.matches?("192.168.1.1").should be_true + end + + it "handles whitespace in CIDR" do + matcher = LavinMQ::IPMatcher.parse(" 192.168.1.0 / 24 ") + matcher.matches?("192.168.1.50").should be_true + end + + it "doesn't match IPv6 address against IPv4 CIDR" do + matcher = LavinMQ::IPMatcher.parse("192.168.1.0/24") + matcher.matches?("2001:db8::1").should be_false + end + + it "doesn't match IPv4 address against IPv6 CIDR" do + matcher = LavinMQ::IPMatcher.parse("2001:db8::/32") + matcher.matches?("192.168.1.1").should be_false + end + + it "CIDR notation can match different IPv6 representations" do + # CIDR notation uses byte comparison and can match different representations + matcher = LavinMQ::IPMatcher.parse("2001:db8::1/128") + matcher.matches?("2001:0db8:0000:0000:0000:0000:0000:0001").should be_true + matcher.matches?("2001:db8::1").should be_true + end + end + end +end diff --git a/spec/proxy_protocol_spec.cr b/spec/proxy_protocol_spec.cr index 03d2dc1ae6..907cacd933 100644 --- a/spec/proxy_protocol_spec.cr +++ b/spec/proxy_protocol_spec.cr @@ -1,4 +1,4 @@ -require "spec" +require "./spec_helper" require "../src/lavinmq/proxy_protocol" describe "ProxyProtocol" do @@ -87,4 +87,88 @@ describe "ProxyProtocol" do ).to_slice end end + + describe "trusted sources" do + it "parses individual IPv4 addresses" do + config = LavinMQ::Config.new + config.proxy_protocol_trusted_sources = config.parse_trusted_sources("192.168.1.1, 10.0.0.1") + + config.proxy_protocol_trusted_sources.size.should eq 2 + config.proxy_protocol_trusted_sources[0].matches?("192.168.1.1").should be_true + config.proxy_protocol_trusted_sources[0].matches?("192.168.1.2").should be_false + end + + it "parses IPv4 CIDR notation" do + config = LavinMQ::Config.new + config.proxy_protocol_trusted_sources = config.parse_trusted_sources("192.168.0.0/24") + + config.proxy_protocol_trusted_sources.size.should eq 1 + config.proxy_protocol_trusted_sources[0].matches?("192.168.0.1").should be_true + config.proxy_protocol_trusted_sources[0].matches?("192.168.0.255").should be_true + config.proxy_protocol_trusted_sources[0].matches?("192.168.1.1").should be_false + end + + it "parses IPv6 CIDR notation" do + config = LavinMQ::Config.new + config.proxy_protocol_trusted_sources = config.parse_trusted_sources("2001:db8::/32") + + config.proxy_protocol_trusted_sources.size.should eq 1 + config.proxy_protocol_trusted_sources[0].matches?("2001:db8::1").should be_true + config.proxy_protocol_trusted_sources[0].matches?("2001:db8:ffff:ffff:ffff:ffff:ffff:ffff").should be_true + config.proxy_protocol_trusted_sources[0].matches?("2001:db9::1").should be_false + end + + it "parses mixed IPs and CIDR notation" do + config = LavinMQ::Config.new + config.proxy_protocol_trusted_sources = config.parse_trusted_sources( + "10.0.0.1, 192.168.0.0/24, 2001:db8::/32, ::1") + + config.proxy_protocol_trusted_sources.size.should eq 4 + + # Exact IPv4 + config.proxy_protocol_trusted_sources[0].matches?("10.0.0.1").should be_true + config.proxy_protocol_trusted_sources[0].matches?("10.0.0.2").should be_false + + # IPv4 CIDR + config.proxy_protocol_trusted_sources[1].matches?("192.168.0.50").should be_true + config.proxy_protocol_trusted_sources[1].matches?("192.168.1.50").should be_false + + # IPv6 CIDR + config.proxy_protocol_trusted_sources[2].matches?("2001:db8::100").should be_true + config.proxy_protocol_trusted_sources[2].matches?("2001:db9::1").should be_false + + # Exact IPv6 + config.proxy_protocol_trusted_sources[3].matches?("::1").should be_true + config.proxy_protocol_trusted_sources[3].matches?("::2").should be_false + end + + it "handles invalid entries gracefully" do + config = LavinMQ::Config.new + # This should print warnings to STDERR but not fail + config.proxy_protocol_trusted_sources = config.parse_trusted_sources( + "10.0.0.1, invalid-ip, 192.168.0.0/24, 300.0.0.1") + + # Only valid entries should be parsed + config.proxy_protocol_trusted_sources.size.should eq 2 + config.proxy_protocol_trusted_sources[0].matches?("10.0.0.1").should be_true + config.proxy_protocol_trusted_sources[1].matches?("192.168.0.50").should be_true + end + + it "handles whitespace correctly" do + config = LavinMQ::Config.new + config.proxy_protocol_trusted_sources = config.parse_trusted_sources( + " 10.0.0.1 , 192.168.0.0/24 ") + + config.proxy_protocol_trusted_sources.size.should eq 2 + config.proxy_protocol_trusted_sources[0].matches?("10.0.0.1").should be_true + config.proxy_protocol_trusted_sources[1].matches?("192.168.0.50").should be_true + end + + it "returns empty array for empty config" do + config = LavinMQ::Config.new + config.proxy_protocol_trusted_sources = config.parse_trusted_sources("") + + config.proxy_protocol_trusted_sources.size.should eq 0 + end + end end diff --git a/src/lavinmq/config.cr b/src/lavinmq/config.cr index e2d47aaa6d..b90fa8c0a7 100644 --- a/src/lavinmq/config.cr +++ b/src/lavinmq/config.cr @@ -7,6 +7,7 @@ require "./log_formatter" require "./in_memory_backend" require "./auth/password" require "./sni_config" +require "./ip_matcher" module LavinMQ class Config @@ -26,8 +27,9 @@ module LavinMQ property mqtts_port = 8883 property mqtt_unix_path = "" property unix_path = "" - property unix_proxy_protocol = 1_u8 # PROXY protocol version on unix domain socket connections - property tcp_proxy_protocol = 0_u8 # PROXY protocol version on amqp tcp connections + property unix_proxy_protocol = 1_u8 # PROXY protocol version on unix domain socket connections + property tcp_proxy_protocol = 0_u8 # PROXY protocol version on amqp tcp connections + property proxy_protocol_trusted_sources = Array(IPMatcher).new # Comma-separated IPs/CIDRs trusted for PROXY protocol property tls_cert_path = "" property tls_key_path = "" property tls_ciphers = "" @@ -371,22 +373,23 @@ module LavinMQ private def parse_amqp(settings) settings.each do |config, v| case config - when "bind" then @amqp_bind = v - when "port" then @amqp_port = v.to_i32 - when "tls_port" then @amqps_port = v.to_i32 - when "tls_cert" then @tls_cert_path = v # backward compatibility - when "tls_key" then @tls_key_path = v # backward compatibility - when "unix_path" then @unix_path = v - when "heartbeat" then @heartbeat = v.to_u16 - when "frame_max" then @frame_max = v.to_u32 - when "channel_max" then @channel_max = v.to_u16 - when "max_message_size" then @max_message_size = v.to_i32 - when "unix_proxy_protocol" then @unix_proxy_protocol = true?(v) ? 1u8 : v.to_u8? || 0u8 - when "tcp_proxy_protocol" then @tcp_proxy_protocol = true?(v) ? 1u8 : v.to_u8? || 0u8 - when "set_timestamp" then @set_timestamp = true?(v) - when "consumer_timeout" then @consumer_timeout = v.to_u64 - when "default_consumer_prefetch" then @default_consumer_prefetch = v.to_u16 - when "max_consumers_per_channel" then @max_consumers_per_channel = v.to_i + when "bind" then @amqp_bind = v + when "port" then @amqp_port = v.to_i32 + when "tls_port" then @amqps_port = v.to_i32 + when "tls_cert" then @tls_cert_path = v # backward compatibility + when "tls_key" then @tls_key_path = v # backward compatibility + when "unix_path" then @unix_path = v + when "heartbeat" then @heartbeat = v.to_u16 + when "frame_max" then @frame_max = v.to_u32 + when "channel_max" then @channel_max = v.to_u16 + when "max_message_size" then @max_message_size = v.to_i32 + when "unix_proxy_protocol" then @unix_proxy_protocol = true?(v) ? 1u8 : v.to_u8? || 0u8 + when "tcp_proxy_protocol" then @tcp_proxy_protocol = true?(v) ? 1u8 : v.to_u8? || 0u8 + when "proxy_protocol_trusted_sources" then @proxy_protocol_trusted_sources = parse_trusted_sources(v) + when "set_timestamp" then @set_timestamp = true?(v) + when "consumer_timeout" then @consumer_timeout = v.to_u64 + when "default_consumer_prefetch" then @default_consumer_prefetch = v.to_u16 + when "max_consumers_per_channel" then @max_consumers_per_channel = v.to_i else STDERR.puts "WARNING: Unrecognized configuration 'amqp/#{config}'" end @@ -484,6 +487,20 @@ module LavinMQ end end + def parse_trusted_sources(config_value : String) : Array(IPMatcher) + config_value.split(',') + .map(&.strip) + .reject(&.empty?) + .compact_map do |source| + begin + IPMatcher.parse(source) + rescue ex : Socket::Error | ArgumentError + STDERR.puts "WARNING: Invalid IP/CIDR in proxy_protocol_trusted_sources: #{source} - #{ex.message}" + nil + end + end + end + private def true?(str : String?) {"true", "yes", "y", "1"}.includes? str end diff --git a/src/lavinmq/ip_matcher.cr b/src/lavinmq/ip_matcher.cr new file mode 100644 index 0000000000..0cf3d2fdd3 --- /dev/null +++ b/src/lavinmq/ip_matcher.cr @@ -0,0 +1,158 @@ +require "socket" + +module LavinMQ + # Matches IP addresses against either exact IPs or CIDR ranges + struct IPMatcher + enum Type + ExactIPv4 + ExactIPv6 + CIDRv4 + CIDRv6 + end + + @type : Type + @address_string : String + @network : Bytes? # Only used for CIDR + @mask : Bytes? # Only used for CIDR + + def initialize(@type : Type, @address_string : String, @network : Bytes? = nil, @mask : Bytes? = nil) + end + + # Parse from config string: "192.168.0.0/24" or "10.0.0.1" + def self.parse(source : String) : IPMatcher + if source.includes?('/') + parse_cidr(source) + else + parse_exact_ip(source) + end + end + + # Check if an IP address matches this matcher + def matches?(address : String) : Bool + case @type + when Type::ExactIPv4, Type::ExactIPv6 + @address_string == address + when Type::CIDRv4 + network = @network + mask = @mask + return false unless network && mask + if target = ip_to_bytes_v4(address) + matches_cidr?(target, network, mask) + else + false + end + when Type::CIDRv6 + network = @network + mask = @mask + return false unless network && mask + if target = ip_to_bytes_v6(address) + matches_cidr?(target, network, mask) + else + false + end + else + false + end + end + + private def self.parse_cidr(source : String) : IPMatcher + parts = source.split('/', 2) + raise ArgumentError.new("Invalid CIDR notation: #{source}") if parts.size != 2 + + ip_str = parts[0].strip + prefix_str = parts[1].strip + prefix = prefix_str.to_u8? + raise ArgumentError.new("Invalid CIDR prefix: #{prefix_str}") unless prefix + + # Try IPv4 first + if fields = Socket::IPAddress.parse_v4_fields?(ip_str) + raise ArgumentError.new("IPv4 prefix must be 0-32, got #{prefix}") if prefix > 32 + network = Bytes.new(4) { |i| fields[i] } + mask = calculate_mask(prefix, 4) + # Apply mask to network address to normalize it + network.size.times { |i| network[i] &= mask[i] } + return new(Type::CIDRv4, source, network, mask) + end + + # Try IPv6 + if fields = Socket::IPAddress.parse_v6_fields?(ip_str) + raise ArgumentError.new("IPv6 prefix must be 0-128, got #{prefix}") if prefix > 128 + network = v6_fields_to_bytes(fields) + mask = calculate_mask(prefix, 16) + # Apply mask to network address to normalize it + network.size.times { |i| network[i] &= mask[i] } + return new(Type::CIDRv6, source, network, mask) + end + + raise ArgumentError.new("Invalid IP address in CIDR: #{ip_str}") + end + + private def self.parse_exact_ip(source : String) : IPMatcher + source = source.strip + + # Try IPv4 - just validate + if Socket::IPAddress.parse_v4_fields?(source) + return new(Type::ExactIPv4, source) + end + + # Try IPv6 - just validate + if Socket::IPAddress.parse_v6_fields?(source) + return new(Type::ExactIPv6, source) + end + + raise ArgumentError.new("Invalid IP address: #{source}") + end + + # Calculate netmask from prefix length + private def self.calculate_mask(prefix : UInt8, total_bytes : Int32) : Bytes + mask = Bytes.new(total_bytes, 0_u8) + full_bytes = prefix // 8 + remaining_bits = prefix % 8 + + # Set full bytes to 0xFF + full_bytes.times { |i| mask[i] = 0xFF_u8 } + + # Set partial byte if any remaining bits + if remaining_bits > 0 && full_bytes < total_bytes + mask[full_bytes] = (0xFF_u8 << (8 - remaining_bits)) + end + + mask + end + + # Convert IPv4 address string to bytes + private def ip_to_bytes_v4(address : String) : Bytes? + if fields = Socket::IPAddress.parse_v4_fields?(address) + Bytes.new(4) { |i| fields[i] } + end + end + + # Convert IPv6 address string to bytes + private def ip_to_bytes_v6(address : String) : Bytes? + if fields = Socket::IPAddress.parse_v6_fields?(address) + IPMatcher.v6_fields_to_bytes(fields) + end + end + + # Convert IPv6 fields array to bytes (shared helper) + protected def self.v6_fields_to_bytes(fields : StaticArray(UInt16, 8)) : Bytes + bytes = Bytes.new(16) + fields.each_with_index do |field, i| + bytes[i * 2] = (field >> 8).to_u8 + bytes[i * 2 + 1] = (field & 0xFF).to_u8 + end + bytes + end + + # Check if target IP matches CIDR range using bitwise AND + private def matches_cidr?(target : Bytes, network : Bytes, mask : Bytes) : Bool + return false unless target.size == network.size == mask.size + + target.size.times do |i| + return false if (target[i] & mask[i]) != (network[i] & mask[i]) + end + + true + end + end +end diff --git a/src/lavinmq/server.cr b/src/lavinmq/server.cr index 4e07e7c963..34b22179d6 100644 --- a/src/lavinmq/server.cr +++ b/src/lavinmq/server.cr @@ -48,6 +48,9 @@ module LavinMQ @mqtt_brokers = MQTT::Brokers.new(@vhosts, @replicator) @parameters = ParameterStore(Parameter).new(@data_dir, "parameters.json", @replicator) @authenticator = Auth::Chain.create(@users) + if @config.tcp_proxy_protocol > 0 && @config.proxy_protocol_trusted_sources.empty? + Log.warn { "PROXY protocol enabled without trusted sources configured - accepting from all sources" } + end @connection_factories = { Protocol::AMQP => AMQP::ConnectionFactory.new(authenticator, @vhosts), Protocol::MQTT => MQTT::ConnectionFactory.new(authenticator, @mqtt_brokers, @config), @@ -135,19 +138,22 @@ module LavinMQ private def extract_conn_info(client) : ConnectionInfo remote_address = client.remote_address case @config.tcp_proxy_protocol - when 1 then ProxyProtocol::V1.parse(client) - when 2 then ProxyProtocol::V2.parse(client) + when 1, 2 + if trusted_proxy_source?(remote_address.address) + parse_proxy_protocol(client, @config.tcp_proxy_protocol) + else + Log.warn { "PROXY protocol from untrusted source #{remote_address}, ignoring header" } + ConnectionInfo.new(remote_address, client.local_address) + end else - # Allow proxy connection from followers + # Accept PROXY protocol from verified cluster followers if @config.clustering? && client.peek[0, 5]? == "PROXY".to_slice && all_followers.any? { |f| f.remote_address.address == remote_address.address } - # Expect PROXY protocol header if remote address is a follower ProxyProtocol::V1.parse(client) elsif @config.clustering? && client.peek[0, 8]? == ProxyProtocol::V2::Signature.to_slice[0, 8] && all_followers.any? { |f| f.remote_address.address == remote_address.address } - # Expect PROXY protocol header if remote address is a follower ProxyProtocol::V2.parse(client) else ConnectionInfo.new(remote_address, client.local_address) @@ -155,6 +161,19 @@ module LavinMQ end end + private def trusted_proxy_source?(address : String) : Bool + return true if @config.proxy_protocol_trusted_sources.empty? + @config.proxy_protocol_trusted_sources.any?(&.matches?(address)) + end + + private def parse_proxy_protocol(client, version : UInt8) : ConnectionInfo + case version + when 1 then ProxyProtocol::V1.parse(client) + when 2 then ProxyProtocol::V2.parse(client) + else raise "Invalid proxy protocol version: #{version}" + end + end + def listen(s : UNIXServer, protocol : Protocol) @listeners[s] = protocol Log.info { "Listening for #{protocol} on #{s.local_address}" }