diff --git a/cmd/devd/devd.go b/cmd/devd/devd.go index 72c82a9..2bb6d0c 100644 --- a/cmd/devd/devd.go +++ b/cmd/devd/devd.go @@ -184,7 +184,7 @@ func main() { hdrs := make(http.Header) if *cors { - hdrs.Set("Access-Control-Allow-Origin", "*") + hdrs.Set("Access-Control-Allow-Credentials", "true") } var servingScheme string @@ -196,9 +196,9 @@ func main() { dd := devd.Devd{ // Shaping - Latency: *latency, - DownKbps: *downKbps, - UpKbps: *upKbps, + Latency: *latency, + DownKbps: *downKbps, + UpKbps: *upKbps, ServingScheme: servingScheme, AddHeaders: &hdrs, @@ -209,6 +209,8 @@ func main() { WatchPaths: *watch, Excludes: *excludes, + Cors: *cors, + Credentials: creds, } diff --git a/httpctx/httpctx.go b/httpctx/httpctx.go index 0cb5df7..9a87a10 100644 --- a/httpctx/httpctx.go +++ b/httpctx/httpctx.go @@ -47,3 +47,14 @@ func StripPrefix(prefix string, h Handler) Handler { } }) } + +func RouteWebsockets(httpHandler, wsHandler Handler) Handler { + return HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { + upgrade := r.Header.Get("Upgrade") + if upgrade == "websocket" { + wsHandler.ServeHTTPContext(ctx, w, r) + } else { + httpHandler.ServeHTTPContext(ctx, w, r) + } + }) +} diff --git a/responselogger.go b/responselogger.go index 6d12a0b..d9c98e7 100644 --- a/responselogger.go +++ b/responselogger.go @@ -1,7 +1,9 @@ package devd import ( + "bufio" "fmt" + "net" "net/http" "strconv" @@ -15,6 +17,7 @@ import ( type ResponseLogWriter struct { Log termlog.Logger Resp http.ResponseWriter + Hijacker http.Hijacker Flusher http.Flusher Timer *timer.Timer wroteHeader bool @@ -87,3 +90,6 @@ func (rl *ResponseLogWriter) Flush() { rl.Flusher.Flush() } } +func (rl *ResponseLogWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return rl.Hijacker.Hijack() +} diff --git a/route.go b/route.go index 6be330a..0f82c78 100644 --- a/route.go +++ b/route.go @@ -14,6 +14,8 @@ import ( "github.com/cortesi/devd/inject" "github.com/cortesi/devd/reverseproxy" "github.com/cortesi/devd/routespec" + "github.com/cortesi/devd/websocketproxy" + "github.com/gorilla/websocket" ) // Endpoint is the destination of a Route - either on the filesystem or @@ -33,7 +35,20 @@ func (ep forwardEndpoint) Handler(prefix string, templates *template.Template, c TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } rp.FlushInterval = 200 * time.Millisecond - return httpctx.StripPrefix(prefix, rp) + + wsURL := url.URL(ep) + switch wsURL.Scheme { + case "http": + wsURL.Scheme = "ws" + case "https": + wsURL.Scheme = "wss" + } + ws := websocketproxy.NewProxy(&wsURL) + ws.Dialer = &websocket.Dialer{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + + return httpctx.StripPrefix(prefix, httpctx.RouteWebsockets(rp, ws)) } func newForwardEndpoint(path string) (*forwardEndpoint, error) { diff --git a/server.go b/server.go index 9dc9885..ab87fbe 100644 --- a/server.go +++ b/server.go @@ -151,6 +151,9 @@ type Devd struct { WatchPaths []string Excludes []string + // Add Access-Control-Allow-Origin header + Cors bool + // Logging IgnoreLogs []*regexp.Regexp @@ -194,10 +197,22 @@ func (dd *Devd) WrapHandler(log termlog.TermLog, next httpctx.Handler) http.Hand } } } + if dd.Cors { + origin := r.Header.Get("Origin") + if origin == "" { + origin = "*" + } + w.Header().Set("Access-Control-Allow-Origin", origin) + requestHeaders := r.Header.Get("Access-Control-Request-Headers") + if requestHeaders != "" { + w.Header().Set("Access-Control-Allow-Headers", requestHeaders) + } + } flusher, _ := w.(http.Flusher) + hijacker, _ := w.(http.Hijacker) next.ServeHTTPContext( ctx, - &ResponseLogWriter{Log: sublog, Resp: w, Flusher: flusher, Timer: &timr}, + &ResponseLogWriter{Log: sublog, Resp: w, Hijacker: hijacker, Flusher: flusher, Timer: &timr}, r, ) }) diff --git a/websocketproxy/.travis.yml b/websocketproxy/.travis.yml new file mode 100644 index 0000000..8670f00 --- /dev/null +++ b/websocketproxy/.travis.yml @@ -0,0 +1,2 @@ +language: go +go: 1.8 diff --git a/websocketproxy/LICENSE.md b/websocketproxy/LICENSE.md new file mode 100644 index 0000000..f0a2a7c --- /dev/null +++ b/websocketproxy/LICENSE.md @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2014 Koding, Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/websocketproxy/README.md b/websocketproxy/README.md new file mode 100644 index 0000000..526bb43 --- /dev/null +++ b/websocketproxy/README.md @@ -0,0 +1,54 @@ +# WebsocketProxy [![GoDoc](https://godoc.org/github.com/koding/websocketproxy?status.svg)](https://godoc.org/github.com/koding/websocketproxy) [![Build Status](https://travis-ci.org/koding/websocketproxy.svg)](https://travis-ci.org/koding/websocketproxy) + +WebsocketProxy is an http.Handler interface build on top of +[gorilla/websocket](https://github.com/gorilla/websocket) that you can plug +into your existing Go webserver to provide WebSocket reverse proxy. + +## Install + +```bash +go get github.com/koding/websocketproxy +``` + +## Example + +Below is a simple server that proxies to the given backend URL + +```go +package main + +import ( + "flag" + "net/http" + "net/url" + + "github.com/koding/websocketproxy" +) + +var ( + flagBackend = flag.String("backend", "", "Backend URL for proxying") +) + +func main() { + u, err := url.Parse(*flagBackend) + if err != nil { + log.Fatalln(err) + } + + err = http.ListenAndServe(":80", websocketproxy.NewProxy(u)) + if err != nil { + log.Fatalln(err) + } +} +``` + +Save it as `proxy.go` and run as: + +```bash +go run proxy.go -backend ws://example.com:3000 +``` + +Now all incoming WebSocket requests coming to this server will be proxied to +`ws://example.com:3000` + + diff --git a/websocketproxy/websocketproxy.go b/websocketproxy/websocketproxy.go new file mode 100644 index 0000000..1331665 --- /dev/null +++ b/websocketproxy/websocketproxy.go @@ -0,0 +1,239 @@ +// Package websocketproxy is a reverse proxy for WebSocket connections. +package websocketproxy + +import ( + "context" + "fmt" + "io" + "log" + "net" + "net/http" + "net/url" + "strings" + + "github.com/gorilla/websocket" +) + +var ( + // DefaultUpgrader specifies the parameters for upgrading an HTTP + // connection to a WebSocket connection. + DefaultUpgrader = &websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + // DefaultDialer is a dialer with all fields set to the default zero values. + DefaultDialer = websocket.DefaultDialer +) + +// WebsocketProxy is an HTTP Handler that takes an incoming WebSocket +// connection and proxies it to another server. +type WebsocketProxy struct { + // Director, if non-nil, is a function that may copy additional request + // headers from the incoming WebSocket connection into the output headers + // which will be forwarded to another server. + Director func(incoming *http.Request, out http.Header) + + // Backend returns the backend URL which the proxy uses to reverse proxy + // the incoming WebSocket connection. Request is the initial incoming and + // unmodified request. + Backend func(*http.Request) *url.URL + + // Upgrader specifies the parameters for upgrading a incoming HTTP + // connection to a WebSocket connection. If nil, DefaultUpgrader is used. + Upgrader *websocket.Upgrader + + // Dialer contains options for connecting to the backend WebSocket server. + // If nil, DefaultDialer is used. + Dialer *websocket.Dialer +} + +// ProxyHandler returns a new http.Handler interface that reverse proxies the +// request to the given target. +func ProxyHandler(target *url.URL) http.Handler { return NewProxy(target) } + +// NewProxy returns a new Websocket reverse proxy that rewrites the +// URL's to the scheme, host and base path provider in target. +func NewProxy(target *url.URL) *WebsocketProxy { + backend := func(r *http.Request) *url.URL { + // Shallow copy + u := *target + u.Fragment = r.URL.Fragment + u.Path = r.URL.Path + u.RawQuery = r.URL.RawQuery + return &u + } + return &WebsocketProxy{Backend: backend} +} + +// ServeHTTP implements the http.Handler that proxies WebSocket connections. +func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + w.ServeHTTPContext(context.Background(), rw, req) +} + +// ServeHTTP implements the http.Handler that proxies WebSocket connections. +func (w *WebsocketProxy) ServeHTTPContext(ctx context.Context, rw http.ResponseWriter, req *http.Request) { + if w.Backend == nil { + log.Println("websocketproxy: backend function is not defined") + http.Error(rw, "internal server error (code: 1)", http.StatusInternalServerError) + return + } + + backendURL := w.Backend(req) + if backendURL == nil { + log.Println("websocketproxy: backend URL is nil") + http.Error(rw, "internal server error (code: 2)", http.StatusInternalServerError) + return + } + + dialer := w.Dialer + if w.Dialer == nil { + dialer = DefaultDialer + } + + // Pass headers from the incoming request to the dialer to forward them to + // the final destinations. + requestHeader := http.Header{} + if origin := req.Header.Get("Origin"); origin != "" { + requestHeader.Add("Origin", origin) + } + for _, prot := range req.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] { + requestHeader.Add("Sec-WebSocket-Protocol", prot) + } + for _, cookie := range req.Header[http.CanonicalHeaderKey("Cookie")] { + requestHeader.Add("Cookie", cookie) + } + if req.Host != "" { + requestHeader.Set("Host", req.Host) + } + + // Pass X-Forwarded-For headers too, code below is a part of + // httputil.ReverseProxy. See http://en.wikipedia.org/wiki/X-Forwarded-For + // for more information + // TODO: use RFC7239 http://tools.ietf.org/html/rfc7239 + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := req.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + requestHeader.Set("X-Forwarded-For", clientIP) + } + + // Set the originating protocol of the incoming HTTP request. The SSL might + // be terminated on our site and because we doing proxy adding this would + // be helpful for applications on the backend. + requestHeader.Set("X-Forwarded-Proto", "http") + if req.TLS != nil { + requestHeader.Set("X-Forwarded-Proto", "https") + } + + // Enable the director to copy any additional headers it desires for + // forwarding to the remote server. + if w.Director != nil { + w.Director(req, requestHeader) + } + + // Connect to the backend URL, also pass the headers we get from the requst + // together with the Forwarded headers we prepared above. + // TODO: support multiplexing on the same backend connection instead of + // opening a new TCP connection time for each request. This should be + // optional: + // http://tools.ietf.org/html/draft-ietf-hybi-websocket-multiplexing-01 + connBackend, resp, err := dialer.DialContext(ctx, backendURL.String(), requestHeader) + if err != nil { + log.Printf("websocketproxy: couldn't dial to remote backend url %s, %s", backendURL.String(), err) + if resp != nil { + // If the WebSocket handshake fails, ErrBadHandshake is returned + // along with a non-nil *http.Response so that callers can handle + // redirects, authentication, etcetera. + if err := copyResponse(rw, resp); err != nil { + log.Printf("websocketproxy: couldn't write response after failed remote backend handshake: %s", err) + } + } else { + http.Error(rw, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable) + } + return + } + defer connBackend.Close() + + upgrader := w.Upgrader + if w.Upgrader == nil { + upgrader = DefaultUpgrader + } + + // Only pass those headers to the upgrader. + upgradeHeader := http.Header{} + if hdr := resp.Header.Get("Sec-Websocket-Protocol"); hdr != "" { + upgradeHeader.Set("Sec-Websocket-Protocol", hdr) + } + if hdr := resp.Header.Get("Set-Cookie"); hdr != "" { + upgradeHeader.Set("Set-Cookie", hdr) + } + + // Now upgrade the existing incoming request to a WebSocket connection. + // Also pass the header that we gathered from the Dial handshake. + connPub, err := upgrader.Upgrade(rw, req, upgradeHeader) + if err != nil { + log.Printf("websocketproxy: couldn't upgrade %s", err) + return + } + defer connPub.Close() + + errClient := make(chan error, 1) + errBackend := make(chan error, 1) + replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) { + for { + msgType, msg, err := src.ReadMessage() + if err != nil { + m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err)) + if e, ok := err.(*websocket.CloseError); ok { + if e.Code != websocket.CloseNoStatusReceived { + m = websocket.FormatCloseMessage(e.Code, e.Text) + } + } + errc <- err + dst.WriteMessage(websocket.CloseMessage, m) + break + } + err = dst.WriteMessage(msgType, msg) + if err != nil { + errc <- err + break + } + } + } + + go replicateWebsocketConn(connPub, connBackend, errClient) + go replicateWebsocketConn(connBackend, connPub, errBackend) + + var message string + select { + case err = <-errClient: + message = "websocketproxy: Error when copying from backend to client: %v" + case err = <-errBackend: + message = "websocketproxy: Error when copying from client to backend: %v" + + } + if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure { + log.Printf(message, err) + } +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +func copyResponse(rw http.ResponseWriter, resp *http.Response) error { + copyHeader(rw.Header(), resp.Header) + rw.WriteHeader(resp.StatusCode) + defer resp.Body.Close() + + _, err := io.Copy(rw, resp.Body) + return err +} diff --git a/websocketproxy/websocketproxy_test.go b/websocketproxy/websocketproxy_test.go new file mode 100644 index 0000000..b90e02b --- /dev/null +++ b/websocketproxy/websocketproxy_test.go @@ -0,0 +1,130 @@ +package websocketproxy + +import ( + "log" + "net/http" + "net/url" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +var ( + serverURL = "ws://127.0.0.1:7777" + backendURL = "ws://127.0.0.1:8888" +) + +func TestProxy(t *testing.T) { + // websocket proxy + supportedSubProtocols := []string{"test-protocol"} + upgrader := &websocket.Upgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + CheckOrigin: func(r *http.Request) bool { + return true + }, + Subprotocols: supportedSubProtocols, + } + + u, _ := url.Parse(backendURL) + proxy := NewProxy(u) + proxy.Upgrader = upgrader + + mux := http.NewServeMux() + mux.Handle("/proxy", proxy) + go func() { + if err := http.ListenAndServe(":7777", mux); err != nil { + t.Fatal("ListenAndServe: ", err) + } + }() + + time.Sleep(time.Millisecond * 100) + + // backend echo server + go func() { + mux2 := http.NewServeMux() + mux2.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + // Don't upgrade if original host header isn't preserved + if r.Host != "127.0.0.1:7777" { + log.Printf("Host header set incorrectly. Expecting 127.0.0.1:7777 got %s", r.Host) + return + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println(err) + return + } + + messageType, p, err := conn.ReadMessage() + if err != nil { + return + } + + if err = conn.WriteMessage(messageType, p); err != nil { + return + } + }) + + err := http.ListenAndServe(":8888", mux2) + if err != nil { + t.Fatal("ListenAndServe: ", err) + } + }() + + time.Sleep(time.Millisecond * 100) + + // let's us define two subprotocols, only one is supported by the server + clientSubProtocols := []string{"test-protocol", "test-notsupported"} + h := http.Header{} + for _, subprot := range clientSubProtocols { + h.Add("Sec-WebSocket-Protocol", subprot) + } + + // frontend server, dial now our proxy, which will reverse proxy our + // message to the backend websocket server. + conn, resp, err := websocket.DefaultDialer.Dial(serverURL+"/proxy", h) + if err != nil { + t.Fatal(err) + } + + // check if the server really accepted only the first one + in := func(desired string) bool { + for _, prot := range resp.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] { + if desired == prot { + return true + } + } + return false + } + + if !in("test-protocol") { + t.Error("test-protocol should be available") + } + + if in("test-notsupported") { + t.Error("test-notsupported should be not recevied from the server.") + } + + // now write a message and send it to the backend server (which goes trough + // proxy..) + msg := "hello kite" + err = conn.WriteMessage(websocket.TextMessage, []byte(msg)) + if err != nil { + t.Error(err) + } + + messageType, p, err := conn.ReadMessage() + if err != nil { + t.Error(err) + } + + if messageType != websocket.TextMessage { + t.Error("incoming message type is not Text") + } + + if msg != string(p) { + t.Errorf("expecting: %s, got: %s", msg, string(p)) + } +}