diff --git a/http.go b/http.go index e0d391a5..e83dfd55 100644 --- a/http.go +++ b/http.go @@ -19,6 +19,7 @@ package groupcache import ( "bytes" "context" + "errors" "fmt" "io" "net/http" @@ -115,16 +116,45 @@ func NewHTTPPoolOpts(self string, o *HTTPPoolOptions) *HTTPPool { // Set updates the pool's list of peers. // Each peer value should be a valid base URL, -// for example "http://example.net:8000". -func (p *HTTPPool) Set(peers ...string) { +// for example "http://example.net:8000". Note +// that the scheme ("http://" or "https://") is +// required. +func (p *HTTPPool) Set(peers ...string) error { p.mu.Lock() defer p.mu.Unlock() - p.peers = consistenthash.New(p.opts.Replicas, p.opts.HashFn) - p.peers.Add(peers...) - p.httpGetters = make(map[string]*httpGetter, len(peers)) + for _, peer := range peers { + // Make sure peer address is valid before using it. Address + // must be a valid base URL. + u, err := url.Parse(peer) + if err != nil { + return err + } + + // Check that scheme is provided per the func-level comment. If + // scheme is missing, host will not be able to communicate with + // peers. Not using strings.Contains() here since peer address + // could be incorrectly typed, such as hhttp or httpss. + if u.Scheme != "http" && u.Scheme != "https" { + return errors.New("peer address not using correct scheme, must be http or https") + } + + // This was added to handle an address such as http:///example.com:8000 + // where the address will be parsed as a path with blank host. + if u.Host == "" { + return errors.New("peer address missing host, possibly extra slashes in address") + } + + // TODO: check if port is valid. Use net.SplitHostPort? + p.httpGetters[peer] = &httpGetter{transport: p.Transport, baseURL: peer + p.opts.BasePath} } + + p.peers = consistenthash.New(p.opts.Replicas, p.opts.HashFn) + p.peers.Add(peers...) + p.httpGetters = make(map[string]*httpGetter, len(peers)) + + return nil } func (p *HTTPPool) PickPeer(key string) (ProtoGetter, bool) { diff --git a/http_test.go b/http_test.go index 132c1173..295db647 100644 --- a/http_test.go +++ b/http_test.go @@ -165,3 +165,51 @@ func awaitAddrReady(t *testing.T, addr string, wg *sync.WaitGroup) { time.Sleep(delay) } } + +func TestSet(t *testing.T) { + // list of peer addresses to test with, and whether or not we expect + // peer to be valid. + // TODO: add IPv6 addresses. + peers := map[string]bool{ + //addres.......................valid + "http://10.0.0.1:8000": true, + "https://example.com:8001": true, + "http://sub.example.com:8002": true, + "https://localhost:8003": true, + + "10.0.0.1:8100": false, + "example.com:8101": false, + "sub.example.com:8102": false, + "localhost:8103": false, + "http:////example.com:8104": false, + "//example.com:8105": false, + "httpss//example.com:8106": false, + "hhttp//example.com:8107": false, + "htxtp//example.com:8108": false, + "http//example:8109": false, + "http//example/path/:8110": false, + "/": false, + "": false, + ":8111": false, + ":http://example.com": false, + } + + // create pool to use for testing, using a known + // good/valid address + pool := NewHTTPPool("http://localhost:8080") + + // try setting peers + for addr, valid := range peers { + err := pool.Set(addr) + if valid && err != nil { + t.Fatal("Peer address NOT valid but should be: " + addr) + return + } + if !valid && err == nil { + t.Fatal("Peer address valid but should NOT be: " + addr) + return + } + } + + return +}