forked from inconshreveable/go-vhost
-
Notifications
You must be signed in to change notification settings - Fork 1
/
mux.go
337 lines (288 loc) · 7.83 KB
/
mux.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
package vhost
import (
"fmt"
"net"
"strings"
"sync"
"time"
)
var (
	// normalize canonicalizes a virtual host name so that registration and
	// lookup are case-insensitive.
	normalize = strings.ToLower

	// isClosed reports whether an error returned by the wrapped listener's
	// Accept means the accept loop should stop.
	//
	// A net.Error that describes itself as temporary is worth retrying and
	// therefore is NOT treated as closed; any other net.Error (for example the
	// one returned once the listener has been closed) is. Errors that do not
	// implement net.Error are not treated as closed.
	isClosed = func(err error) bool {
		if netErr, ok := err.(net.Error); ok {
			return !netErr.Temporary()
		}
		return false
	}
)
// NotFound wraps the error reported by the muxer when a connection asks for a
// virtual host name that has no registered listener (exact or wildcard).
type NotFound struct {
	error
}
// BadRequest wraps the error reported by the muxer when the virtual host name
// could not be extracted from a newly accepted connection.
type BadRequest struct {
	error
}
// Closed wraps the Accept error reported by the muxer when it decides the
// wrapped listener is closed; the muxer's accept loop stops after reporting it.
type Closed struct {
	error
}
// muxFn is applied to a freshly accepted net.Conn to obtain a
// virtual-host multiplexed connection (or an error if that fails).
type muxFn func(net.Conn) (Conn, error)

// muxErr pairs an error encountered while multiplexing with the connection
// it occurred on.
type muxErr struct {
	err  error
	conn net.Conn
}
// VhostMuxer accepts connections from a single wrapped net.Listener and hands
// each one to the Listener registered for the virtual host name the
// connection asks for. Failures are reported over an internal error channel.
type VhostMuxer struct {
	listener     net.Listener         // listener on which we mux connections
	muxTimeout   time.Duration        // a connection fails if it doesn't send enough data to mux after this timeout
	vhostFn      muxFn                // new connections are multiplexed by applying this function
	muxErrors    chan muxErr          // all muxing errors are sent over this channel
	registry     map[string]*Listener // registry of name -> listener
	sync.RWMutex                      // protects the registry
}
// NewVhostMuxer builds a VhostMuxer around listener and immediately starts
// its accept loop in a new goroutine.
//
// vhostFn is applied to every accepted connection to discover its virtual
// host name, and muxTimeout bounds how long a connection may take to provide
// that information. The returned error is currently always nil.
func NewVhostMuxer(listener net.Listener, vhostFn muxFn, muxTimeout time.Duration) (*VhostMuxer, error) {
	m := new(VhostMuxer)
	m.listener = listener
	m.vhostFn = vhostFn
	m.muxTimeout = muxTimeout
	m.registry = make(map[string]*Listener)
	m.muxErrors = make(chan muxErr)

	// Accept connections in the background from now on.
	go m.run()

	return m, nil
}
// Listen registers name (after normalization) with the muxer and returns a
// listener that receives the connections made to that name.
//
// An error is returned if the normalized name is already registered.
func (m *VhostMuxer) Listen(name string) (net.Listener, error) {
	key := normalize(name)
	l := &Listener{name: key, mux: m, accept: make(chan Conn)}

	err := m.set(key, l)
	if err != nil {
		return nil, err
	}
	return l, nil
}
// NextError blocks until the muxer reports an error and returns it together
// with the connection it occurred on. The net.Conn is nil when the error came
// from the wrapped listener's Accept().
func (m *VhostMuxer) NextError() (net.Conn, error) {
	report := <-m.muxErrors
	return report.conn, report.err
}
// Close closes the wrapped listener. The listener's Close error is discarded
// because this method has no return value.
func (m *VhostMuxer) Close() {
	_ = m.listener.Close()
}
// run is the accept loop of the VhostMuxer. Every accepted connection is
// handled in its own goroutine. An Accept error that isClosed recognizes is
// reported as Closed and ends the loop; any other Accept error is reported
// as-is and the loop keeps going.
func (m *VhostMuxer) run() {
	for {
		conn, err := m.listener.Accept()
		switch {
		case err == nil:
			go m.handle(conn)
		case isClosed(err):
			m.sendError(nil, Closed{err})
			return
		default:
			m.sendError(nil, err)
		}
	}
}
// handle multiplexes one connection accepted from the wrapped listener: it
// extracts the virtual host name, finds the matching Listener and delivers
// the connection to it. Every failure is reported through sendError.
func (m *VhostMuxer) handle(conn net.Conn) {
	// Turn a panic raised anywhere below into a reported mux error.
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		m.sendError(conn, fmt.Errorf("NameMux.handle failed with error %v", r))
	}()

	// Bound the time spent deciding how to multiplex so that dead
	// connections are detected.
	deadline := time.Now().Add(m.muxTimeout)
	if err := conn.SetDeadline(deadline); err != nil {
		m.sendError(conn, fmt.Errorf("Failed to set deadline: %v", err))
		return
	}

	// Extract the virtual host name.
	vconn, err := m.vhostFn(conn)
	if err != nil {
		m.sendError(conn, BadRequest{fmt.Errorf("Failed to extract vhost name: %v", err)})
		return
	}

	// Find the listener registered for the normalized name.
	host := normalize(vconn.Host())
	target, found := m.get(host)
	if !found {
		m.sendError(vconn, NotFound{fmt.Errorf("Host not found: %v", host)})
		return
	}

	// The muxing deadline no longer applies once the connection is routed.
	if err := vconn.SetDeadline(time.Time{}); err != nil {
		m.sendError(vconn, fmt.Errorf("Failed unset connection deadline: %v", err))
		return
	}

	// Blocks until the listener's owner accepts the connection.
	target.accept <- vconn
}
// sendError publishes err (and the connection it relates to, possibly nil) on
// the muxer's error channel. The channel is unbuffered, so this blocks until
// the error is received.
func (m *VhostMuxer) sendError(conn net.Conn, err error) {
	report := muxErr{err: err, conn: conn}
	m.muxErrors <- report
}
// get looks up the listener registered for name while holding the read lock.
//
// An exact match wins. Otherwise leading labels are successively replaced by
// a wildcard: for "a.b.c" the candidates "*.b.c" and then "*.c" are tried.
// A bare "*" is never tried.
func (m *VhostMuxer) get(name string) (l *Listener, ok bool) {
	m.RLock()
	defer m.RUnlock()

	if l, ok = m.registry[name]; ok {
		return l, ok
	}

	// No exact match: look for a matching wildcard.
	labels := strings.Split(name, ".")
	for i := 0; i+1 < len(labels); i++ {
		labels[i] = "*"
		pattern := strings.Join(labels[i:], ".")
		if l, ok = m.registry[pattern]; ok {
			return l, ok
		}
	}
	return l, ok
}
// set registers l under name while holding the write lock. It fails if the
// name is already present in the registry.
func (m *VhostMuxer) set(name string, l *Listener) error {
	m.Lock()
	defer m.Unlock()

	_, taken := m.registry[name]
	if taken {
		return fmt.Errorf("name %s is already bound", name)
	}

	m.registry[name] = l
	return nil
}
// del removes name from the registry while holding the write lock. Removing
// a name that is not registered is a no-op.
func (m *VhostMuxer) del(name string) {
	m.Lock()
	delete(m.registry, name)
	m.Unlock()
}
// Canned HTTP/1.0 responses written to connections that could not be muxed.
//
// Each response consists of a status line and a Content-Length header, both
// terminated by CRLF, then the empty line that separates the header section
// from the body, then the body. The advertised Content-Length equals the
// body's byte length including its trailing newline.
const (
	serverError = "HTTP/1.0 500 Internal Server Error\r\n" +
		"Content-Length: 22\r\n" +
		"\r\n" +
		"Internal Server Error\n"

	notFound = "HTTP/1.0 404 Not Found\r\n" +
		"Content-Length: 14\r\n" +
		"\r\n" +
		"404 not found\n"

	badRequest = "HTTP/1.0 400 Bad Request\r\n" +
		"Content-Length: 12\r\n" +
		"\r\n" +
		"Bad Request\n"
)
// HTTPMuxer is a VhostMuxer (embedded) with HTTP-flavoured error handling:
// its HandleError method answers failed connections with a canned HTTP
// response before closing them.
type HTTPMuxer struct {
	*VhostMuxer
}
// HandleErrors receives muxing errors via NextError() and passes each one to
// HandleError. It loops forever, so run it in its own goroutine. You must
// invoke this function if you do not want to handle the errors yourself.
func (m *HTTPMuxer) HandleErrors() {
	for {
		conn, err := m.NextError()
		m.HandleError(conn, err)
	}
}
// HandleError responds to a single muxing error.
//
// A Closed error is ignored entirely. For every other error the connection,
// if there is one, receives a canned HTTP response (404 for NotFound, 400 for
// BadRequest, 500 for anything else) and is then closed. A nil conn is
// tolerated for every error type, not only for the 500 case.
func (m *HTTPMuxer) HandleError(conn net.Conn, err error) {
	var response string
	switch err.(type) {
	case Closed:
		return
	case NotFound:
		response = notFound
	case BadRequest:
		response = badRequest
	default:
		response = serverError
	}

	// Errors from the wrapped listener's Accept carry no connection.
	if conn == nil {
		return
	}

	// Best effort: the connection is being dropped anyway, so failures to
	// write the response or to close are deliberately ignored.
	_, _ = conn.Write([]byte(response))
	_ = conn.Close()
}
// NewHTTPMuxer starts muxing HTTP connections accepted from listener. Each
// connection is passed to HTTP to determine its virtual host; muxTimeout is
// forwarded to NewVhostMuxer unchanged.
func NewHTTPMuxer(listener net.Listener, muxTimeout time.Duration) (*HTTPMuxer, error) {
	extract := func(c net.Conn) (Conn, error) {
		return HTTP(c)
	}
	base, err := NewVhostMuxer(listener, extract, muxTimeout)
	return &HTTPMuxer{VhostMuxer: base}, err
}
// TLSMuxer is a VhostMuxer (embedded) for TLS connections. It overrides
// Listen to strip any port from the name and provides its own HandleErrors.
type TLSMuxer struct {
	*VhostMuxer
}
// HandleErrors is the default error handler for TLS muxers. For now it just
// closes every connection that produced a muxing error (invalid, or destined
// for a virtual host name nobody is listening for). Errors without a
// connection are skipped, except Closed, which makes the method return.
// You must invoke this function if you do not want to handle the errors yourself.
func (m *TLSMuxer) HandleErrors() {
	for {
		conn, err := m.NextError()
		if conn != nil {
			// XXX: respond with valid TLS close messages
			conn.Close()
			continue
		}
		if _, closed := err.(Closed); closed {
			return
		}
	}
}
// Listen registers name with the underlying VhostMuxer. If name has the form
// host:port only the host part is registered, because TLS SNI never includes
// the port; a name that cannot be split that way is used as given.
func (m *TLSMuxer) Listen(name string) (net.Listener, error) {
	if host, _, err := net.SplitHostPort(name); err == nil {
		name = host
	}
	return m.VhostMuxer.Listen(name)
}
// NewTLSMuxer starts muxing TLS connections accepted from listener. Each
// connection is passed to TLS to determine its virtual host; muxTimeout is
// forwarded to NewVhostMuxer unchanged.
func NewTLSMuxer(listener net.Listener, muxTimeout time.Duration) (*TLSMuxer, error) {
	extract := func(c net.Conn) (Conn, error) {
		return TLS(c)
	}
	base, err := NewVhostMuxer(listener, extract, muxTimeout)
	return &TLSMuxer{VhostMuxer: base}, err
}
// Listener is returned by a call to Listen() on a muxer. A Listener
// only receives connections that were made to the name passed into the muxer's
// Listen call.
//
// Listener implements the net.Listener interface, so you can Accept() new
// connections and Close() it when finished. When you Close() a Listener,
// the parent muxer will stop listening for connections to the Listener's name.
// Close may be called more than once; only the first call has any effect.
type Listener struct {
	name      string      // normalized virtual host name this listener is registered under
	mux       *VhostMuxer // parent muxer that routes connections here
	accept    chan Conn   // connections routed to this listener; closed by Close
	closeOnce sync.Once   // guards the unregister-and-close sequence in Close
}

// Accept returns the next mux'd connection for this listener and blocks
// until one is available. Once the listener has been closed it returns an
// error instead.
func (l *Listener) Accept() (net.Conn, error) {
	conn, ok := <-l.accept
	if !ok {
		return nil, fmt.Errorf("Listener closed")
	}
	return conn, nil
}

// Close stops the parent muxer from listening for connections to the mux'd
// virtual host name and unblocks pending Accept calls. It always returns nil.
//
// The work is done at most once: a repeated Close neither panics on the
// already-closed channel nor removes a different listener that has since been
// registered under the same name.
func (l *Listener) Close() error {
	l.closeOnce.Do(func() {
		// Unregister first so no new connection is routed to a closed channel.
		l.mux.del(l.name)
		close(l.accept)
	})
	return nil
}

// Addr returns the address of the bound listener used by the parent muxer.
func (l *Listener) Addr() net.Addr {
	// XXX: include name in address?
	return l.mux.listener.Addr()
}

// Name returns the name of the virtual host this listener receives connections on.
func (l *Listener) Name() string {
	return l.name
}