Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions server/public_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2766,6 +2766,11 @@ func Test_sanitizePagingParams(t *testing.T) {
{"oversized page size", 1, maxWebsocketBlockPageSize + 1, txsInAPI, maxWebsocketBlockPageSize, 1, maxWebsocketBlockPageSize},
{"negative values", -1, -1, txsInAPI, maxWebsocketBlockPageSize, 0, txsInAPI},
{"safe offset clamp", maxPageNumber, maxPageNumber, maxPageNumber, maxPageNumber, maxSafePagingOffset / maxPageNumber, maxPageNumber},
// WS getAccountInfo arguments: default 25, cap at txsInAPI.
{"ws getAccountInfo default", 0, 0, txsOnPage, txsInAPI, 0, txsOnPage},
{"ws getAccountInfo within limit", 1, 100, txsOnPage, txsInAPI, 1, 100},
{"ws getAccountInfo caps at txsInAPI", 1, txsInAPI + 1, txsOnPage, txsInAPI, 1, txsInAPI},
{"ws getAccountInfo negative defaults", 0, -5, txsOnPage, txsInAPI, 0, txsOnPage},
}

for _, tt := range tests {
Expand All @@ -2779,3 +2784,26 @@ func Test_sanitizePagingParams(t *testing.T) {
})
}
}

func Test_validateIntValue_gapClamp(t *testing.T) {
// Mirrors the WS getAccountInfo gap clamp: validateIntValue(req.Gap, 0, 0, maxGapValue).
tests := []struct {
name string
val int
want int
}{
{"unset passes through as 0", 0, 0},
{"suite default 20 passes through", 20, 20},
{"negative defaults to 0", -1, 0},
{"caps at maxGapValue", maxGapValue + 1, maxGapValue},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := validateIntValue(tt.val, 0, 0, maxGapValue)
if got != tt.want {
t.Errorf("validateIntValue(%d, 0, 0, %d) = %d, want %d",
tt.val, maxGapValue, got, tt.want)
}
})
}
}
158 changes: 149 additions & 9 deletions server/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package server
import (
"encoding/json"
"math/big"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"runtime/debug"
Expand All @@ -29,6 +31,14 @@ const defaultTimeout = 60 * time.Second
const unknownMethodLabel = "unknown"
const maxWebsocketMessageBytes int64 = 4 * 1024 * 1024
const maxWebsocketPendingRequests = 48
const maxWebsocketConnectionAttemptsPerIP = 64
const maxWebsocketConnectionsPerIP = 128
const maxWebsocketEstimateFeeBlocks = 32
const maxWebsocketSubscribeAddresses = 1000
const maxWebsocketSubscribeAddressesWithNewBlockTxs = 100
const websocketConnectionAttemptWindow = time.Minute
const websocketConnectionLimiterTTL = 10 * time.Minute
const websocketConnectionLimiterCleanupInterval = time.Minute
const websocketLogPreviewBytes = 256

// allRates is a special "currency" parameter that means all available currencies
Expand Down Expand Up @@ -90,6 +100,19 @@ type WebsocketServer struct {
fiatRatesSubscriptionsLock sync.Mutex
allowedOrigins map[string]struct{}
allowedRpcCallTo map[string]struct{}
websocketLimiter *websocketConnectionLimiter
}

type websocketClientLimit struct {
active int
attempts []time.Time
lastSeen time.Time
}

type websocketConnectionLimiter struct {
mux sync.Mutex
clients map[string]*websocketClientLimit
lastCleanup time.Time
}

// NewWebsocketServer creates new websocket interface to blockbook and returns its handle
Expand Down Expand Up @@ -118,6 +141,7 @@ func NewWebsocketServer(db *db.RocksDB, chain bchain.BlockChain, mempool bchain.
addressSubscriptions: make(map[string]map[*websocketChannel]*addressDetails),
fiatRatesSubscriptions: make(map[string]map[*websocketChannel]string),
fiatRatesTokenSubscriptions: make(map[*websocketChannel][]string),
websocketLimiter: newWebsocketConnectionLimiter(),
}
s.upgrader = &websocket.Upgrader{
ReadBufferSize: 1024 * 32,
Expand Down Expand Up @@ -191,16 +215,106 @@ func normalizeOrigin(origin string) (string, bool) {
return strings.ToLower(u.Scheme) + "://" + strings.ToLower(u.Host), true
}

func newWebsocketConnectionLimiter() *websocketConnectionLimiter {
return &websocketConnectionLimiter{
clients: make(map[string]*websocketClientLimit),
}
}

func (l *websocketConnectionLimiter) accept(ip string, now time.Time) (bool, string) {
l.mux.Lock()
defer l.mux.Unlock()

l.cleanupLocked(now)
client := l.clients[ip]
if client == nil {
client = &websocketClientLimit{}
l.clients[ip] = client
}
client.lastSeen = now
client.trimAttempts(now)

if client.active >= maxWebsocketConnectionsPerIP {
return false, "connection_limit"
}
if len(client.attempts) >= maxWebsocketConnectionAttemptsPerIP {
return false, "connection_attempt_limit"
}

client.attempts = append(client.attempts, now)
client.active++
return true, ""
}

func (l *websocketConnectionLimiter) release(ip string, now time.Time) {
l.mux.Lock()
defer l.mux.Unlock()

client := l.clients[ip]
if client == nil {
return
}
if client.active > 0 {
client.active--
}
client.lastSeen = now
l.cleanupLocked(now)
}

func (l *websocketConnectionLimiter) cleanupLocked(now time.Time) {
if !l.lastCleanup.IsZero() && now.Sub(l.lastCleanup) < websocketConnectionLimiterCleanupInterval {
return
}
l.lastCleanup = now
for ip, client := range l.clients {
client.trimAttempts(now)
if client.active == 0 && now.Sub(client.lastSeen) > websocketConnectionLimiterTTL {
delete(l.clients, ip)
}
}
}
Comment on lines +320 to +335
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

websocketConnectionLimiter cleanup only runs on accept()/release() calls. If the server experiences a burst of many unique IPs and then becomes idle (no new connections), stale entries older than websocketConnectionLimiterTTL will never be evicted, so the clients map can retain memory indefinitely. Consider running cleanup on a background ticker, or otherwise ensuring eviction happens without requiring subsequent connection activity.

Copilot uses AI. Check for mistakes.

func (client *websocketClientLimit) trimAttempts(now time.Time) {
cutoff := now.Add(-websocketConnectionAttemptWindow)
i := 0
for i < len(client.attempts) && client.attempts[i].Before(cutoff) {
i++
}
if i > 0 {
copy(client.attempts, client.attempts[i:])
client.attempts = client.attempts[:len(client.attempts)-i]
}
}

func getIP(r *http.Request) string {
ip := r.Header.Get("cf-connecting-ip")
if ip != "" {
if ip, ok := parseIP(r.Header.Get("CF-Connecting-IPv6")); ok {
return ip
}
ip = r.Header.Get("X-Real-Ip")
if ip != "" {
if ip, ok := parseIP(r.Header.Get("CF-Connecting-IP")); ok {
return ip
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getIP() trusts CF-Connecting-* headers unconditionally. Because the websocket connection caps are keyed off getIP(), a client can bypass the limiter by spoofing these headers unless the server is guaranteed to sit behind Cloudflare (or another trusted proxy that strips/overwrites them). Consider only honoring these headers when the TCP peer (RemoteAddr) is in a configured set of trusted proxy CIDRs, or gate this behavior behind an explicit "trust proxy headers" config.

Copilot uses AI. Check for mistakes.
}
return r.RemoteAddr

host := r.RemoteAddr
if h, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
host = h
}
if ip, ok := parseIP(host); ok {
return ip
}

return strings.TrimSpace(r.RemoteAddr)
}

func parseIP(value string) (string, bool) {
value = strings.TrimSpace(value)
if value == "" {
return "", false
}
ip, err := netip.ParseAddr(value)
if err != nil {
return "", false
}
return ip.String(), true
}

func getWebsocketPayloadPreview(d []byte) string {
Expand All @@ -216,8 +330,22 @@ func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, upgradeFailed+ErrorMethodNotAllowed.Error(), http.StatusServiceUnavailable)
return
}
ip := getIP(r)
limited := false
if s.websocketLimiter != nil {
ok, reason := s.websocketLimiter.accept(ip, time.Now())
if !ok {
glog.Warning("Websocket connection rejected, ", ip, ", ", reason)
http.Error(w, "Too many websocket connections", http.StatusTooManyRequests)
return
}
limited = true
}
conn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil {
if limited {
s.websocketLimiter.release(ip, time.Now())
}
http.Error(w, upgradeFailed+err.Error(), http.StatusServiceUnavailable)
return
}
Expand All @@ -227,7 +355,7 @@ func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
conn: conn,
out: make(chan *WsRes, outChannelSize),
pendingRequests: make(chan struct{}, maxWebsocketPendingRequests),
ip: getIP(r),
ip: ip,
requestHeader: r.Header,
alive: true,
}
Expand Down Expand Up @@ -381,6 +509,9 @@ func (s *WebsocketServer) onDisconnect(c *websocketChannel) {
s.unsubscribeNewTransaction(c)
s.unsubscribeAddresses(c)
s.unsubscribeFiatRates(c)
if s.websocketLimiter != nil {
s.websocketLimiter.release(c.ip, time.Now())
}
glog.Info("Client disconnected ", c.id, ", ", c.ip)
s.metrics.WebsocketClients.Dec()
}
Expand Down Expand Up @@ -689,9 +820,8 @@ func (s *WebsocketServer) getAccountInfo(req *WsAccountInfoReq) (res *api.Addres
TokensToReturn: tokensToReturn,
Protocols: req.Protocols,
}
if req.PageSize == 0 {
req.PageSize = txsOnPage
}
req.Page, req.PageSize = sanitizePagingParams(req.Page, req.PageSize, txsOnPage, txsInAPI)
req.Gap = validateIntValue(req.Gap, 0, 0, maxGapValue)
a, err := s.api.GetXpubAddress(req.Descriptor, req.Page, req.PageSize, opt, &filter, req.Gap, strings.ToLower(req.SecondaryCurrency))
if err != nil {
return s.api.GetAddress(req.Descriptor, req.Page, req.PageSize, opt, &filter, strings.ToLower(req.SecondaryCurrency))
Expand Down Expand Up @@ -792,6 +922,9 @@ func (s *WebsocketServer) estimateFee(params []byte) (interface{}, error) {
if err != nil {
return nil, err
}
if len(r.Blocks) > maxWebsocketEstimateFeeBlocks {
return nil, api.NewAPIError("blocks max "+strconv.Itoa(maxWebsocketEstimateFeeBlocks), true)
}
res := make([]WsEstimateFeeRes, len(r.Blocks))
if s.chainParser.GetChainType() == bchain.ChainEthereumType {
gas, err := s.chain.EthereumTypeEstimateGas(r.Specific)
Expand Down Expand Up @@ -1017,6 +1150,13 @@ func (s *WebsocketServer) unmarshalAddresses(params []byte) ([]string, bool, err
if err != nil {
return nil, false, api.NewAPIError("Invalid subscribeAddresses params", true)
}
limit := maxWebsocketSubscribeAddresses
if r.NewBlockTxs {
limit = maxWebsocketSubscribeAddressesWithNewBlockTxs
}
if len(r.Addresses) > limit {
return nil, false, api.NewAPIError("addresses max "+strconv.Itoa(limit), true)
}
rv := make([]string, len(r.Addresses))
for i, a := range r.Addresses {
ad, err := s.chainParser.GetAddrDescFromAddress(a)
Expand Down
Loading
Loading