diff --git a/util/httputils/httputils.go b/util/httputils/httputils.go index 0c030db..554172f 100644 --- a/util/httputils/httputils.go +++ b/util/httputils/httputils.go @@ -281,27 +281,24 @@ func ErrorMsg(err error) string { func GetAddrPort(urlStr string) (string, int, error) { parts, err := url.Parse(urlStr) if err != nil { - return "", 0, err + return "", 0, errors.Wrapf(err, "url.Parse %s", urlStr) } - host := parts.Host - hostAddr, hostPortStr, err := net.SplitHostPort(host) - if err == nil { - hostPort, err := strconv.ParseInt(hostPortStr, 10, 32) - if err != nil { - return "", 0, errors.Wrapf(err, "strconv.ParseInt port string %s", hostPortStr) - } else { - return hostAddr, int(hostPort), nil - } - } else { + portStr := parts.Port() + if len(portStr) == 0 { switch parts.Scheme { case "http": - return parts.Host, 80, nil + return parts.Hostname(), 80, nil case "https": - return parts.Host, 443, nil + return parts.Hostname(), 443, nil default: - return "", 0, errors.Wrapf(errors.ErrInvalidFormat, "Unknown schema %s", parts.Scheme) + return "", 0, errors.Errorf("Unknown schema %s", parts.Scheme) } } + port, err := strconv.Atoi(portStr) + if err != nil { + return "", 0, errors.Wrapf(err, "strconv.Atoi port string %s", portStr) + } + return parts.Hostname(), port, nil } func GetTransport(insecure bool) *http.Transport { diff --git a/util/httputils/httputils_test.go b/util/httputils/httputils_test.go index 22b5c27..b14263c 100644 --- a/util/httputils/httputils_test.go +++ b/util/httputils/httputils_test.go @@ -294,15 +294,53 @@ func TestIdleTimeout(t *testing.T) { } } -/*func TestDialTimeout(t *testing.T) { - cli := GetAdaptiveTimeoutClient() - resp, err := cli.Get(fmt.Sprintf("http://192.0.0.1:48481")) - if err == nil { - t.Errorf("Read shoud error") - } else if !err.(*url.Error).Timeout() { - t.Errorf("Read error %s %s, should be url.Error.Timeout", err, reflect.TypeOf(err)) - } else { - t.Logf("Read error %s %s", err, reflect.TypeOf(err)) +type addrPort struct { + url string + host string + port int +} + +var addrPortList = []addrPort{ + { + url: "http://192.0.0.1:48481", + host: "192.0.0.1", + port: 48481, + }, + { + url: "https://192.0.0.1", + host: "192.0.0.1", + port: 443, + }, + { + url: "https://[fc00::300:100]/api/s/identity/v3/auth/tokens", + host: "fc00::300:100", + port: 443, + }, + { + url: "http://[fc00::300:100]/api/s/identity/v3/auth/tokens", + host: "fc00::300:100", + port: 80, + }, + { + url: "https://[fc00::300:100]:3000/api/s/identity/v3/auth/tokens", + host: "fc00::300:100", + port: 3000, + }, + { + url: "https://192.0.0.1:48481", + host: "192.0.0.1", + port: 48481, + }, +} + +func TestAddrPort(t *testing.T) { + for _, addrPort := range addrPortList { + host, port, err := GetAddrPort(addrPort.url) + if err != nil { + t.Errorf("GetAddrPort error %s", err) + } + if host != addrPort.host || port != addrPort.port { + t.Errorf("GetAddrPort error %s, %s => %s, %d => %d", addrPort.url, host, addrPort.host, port, addrPort.port) + } } - CloseResponse(resp) -}*/ +}