Skip to content

Commit dd17c7b

Browse files
committed
treewide: replace gorilla/mux with http.ServeMux
Signed-off-by: Sumner Evans <[email protected]>
1 parent 66c4178 commit dd17c7b

File tree

11 files changed

+170
-185
lines changed

11 files changed

+170
-185
lines changed

appservice/appservice.go

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import (
1919
"syscall"
2020
"time"
2121

22-
"github.com/gorilla/mux"
2322
"github.com/gorilla/websocket"
2423
"github.com/rs/zerolog"
2524
"golang.org/x/net/publicsuffix"
@@ -43,7 +42,7 @@ func Create() *AppService {
4342
intents: make(map[id.UserID]*IntentAPI),
4443
HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar},
4544
StateStore: mautrix.NewMemoryStateStore().(StateStore),
46-
Router: mux.NewRouter(),
45+
Router: http.NewServeMux(),
4746
UserAgent: mautrix.DefaultUserAgent,
4847
txnIDC: NewTransactionIDCache(128),
4948
Live: true,
@@ -61,12 +60,12 @@ func Create() *AppService {
6160
DefaultHTTPRetries: 4,
6261
}
6362

64-
as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut)
65-
as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet)
66-
as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet)
67-
as.Router.HandleFunc("/_matrix/app/v1/ping", as.PostPing).Methods(http.MethodPost)
68-
as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet)
69-
as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet)
63+
as.Router.HandleFunc("PUT /_matrix/app/v1/transactions/{txnID}", as.PutTransaction)
64+
as.Router.HandleFunc("GET /_matrix/app/v1/rooms/{roomAlias}", as.GetRoom)
65+
as.Router.HandleFunc("GET /_matrix/app/v1/users/{userID}", as.GetUser)
66+
as.Router.HandleFunc("POST /_matrix/app/v1/ping", as.PostPing)
67+
as.Router.HandleFunc("GET /_matrix/mau/live", as.GetLive)
68+
as.Router.HandleFunc("GET /_matrix/mau/ready", as.GetReady)
7069

7170
return as
7271
}
@@ -160,7 +159,7 @@ type AppService struct {
160159
QueryHandler QueryHandler
161160
StateStore StateStore
162161

163-
Router *mux.Router
162+
Router *http.ServeMux
164163
UserAgent string
165164
server *http.Server
166165
HTTPClient *http.Client

appservice/http.go

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import (
1717
"syscall"
1818
"time"
1919

20-
"github.com/gorilla/mux"
2120
"github.com/rs/zerolog"
2221

2322
"maunium.net/go/mautrix"
@@ -106,8 +105,7 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) {
106105
return
107106
}
108107

109-
vars := mux.Vars(r)
110-
txnID := vars["txnID"]
108+
txnID := r.PathValue("txnID")
111109
if len(txnID) == 0 {
112110
Error{
113111
ErrorCode: ErrNoTransactionID,
@@ -263,9 +261,7 @@ func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) {
263261
return
264262
}
265263

266-
vars := mux.Vars(r)
267-
roomAlias := vars["roomAlias"]
268-
ok := as.QueryHandler.QueryAlias(roomAlias)
264+
ok := as.QueryHandler.QueryAlias(r.PathValue("roomAlias"))
269265
if ok {
270266
WriteBlankOK(w)
271267
} else {
@@ -282,9 +278,7 @@ func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) {
282278
return
283279
}
284280

285-
vars := mux.Vars(r)
286-
userID := id.UserID(vars["userID"])
287-
ok := as.QueryHandler.QueryUser(userID)
281+
ok := as.QueryHandler.QueryUser(id.UserID(r.PathValue("userID")))
288282
if ok {
289283
WriteBlankOK(w)
290284
} else {

bridgev2/matrix/connector.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"encoding/json"
1414
"errors"
1515
"fmt"
16+
"net/http"
1617
"net/url"
1718
"os"
1819
"regexp"
@@ -21,7 +22,6 @@ import (
2122
"time"
2223
"unsafe"
2324

24-
"github.com/gorilla/mux"
2525
_ "github.com/lib/pq"
2626
"github.com/rs/zerolog"
2727
"go.mau.fi/util/dbutil"
@@ -216,7 +216,8 @@ func (br *Connector) GetPublicAddress() string {
216216
return br.Config.AppService.PublicAddress
217217
}
218218

219-
func (br *Connector) GetRouter() *mux.Router {
219+
// TODO switch to http.ServeMux
220+
func (br *Connector) GetRouter() *http.ServeMux {
220221
if br.GetPublicAddress() != "" {
221222
return br.AS.Router
222223
}

bridgev2/matrix/provisioning.go

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ import (
1616
"sync"
1717
"time"
1818

19-
"github.com/gorilla/mux"
2019
"github.com/rs/xid"
2120
"github.com/rs/zerolog"
2221
"github.com/rs/zerolog/hlog"
22+
"go.mau.fi/util/exhttp"
2323
"go.mau.fi/util/jsontime"
2424
"go.mau.fi/util/requestlog"
2525

@@ -37,7 +37,7 @@ type matrixAuthCacheEntry struct {
3737
}
3838

3939
type ProvisioningAPI struct {
40-
Router *mux.Router
40+
Router *http.ServeMux
4141

4242
br *Connector
4343
log zerolog.Logger
@@ -72,12 +72,12 @@ func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User {
7272
return r.Context().Value(provisioningUserKey).(*bridgev2.User)
7373
}
7474

75-
func (prov *ProvisioningAPI) GetRouter() *mux.Router {
75+
func (prov *ProvisioningAPI) GetRouter() *http.ServeMux {
7676
return prov.Router
7777
}
7878

7979
type IProvisioningAPI interface {
80-
GetRouter() *mux.Router
80+
GetRouter() *http.ServeMux
8181
GetUser(r *http.Request) *bridgev2.User
8282
}
8383

@@ -96,44 +96,38 @@ func (prov *ProvisioningAPI) Init() {
9696
tp.Dialer.Timeout = 10 * time.Second
9797
tp.Transport.ResponseHeaderTimeout = 10 * time.Second
9898
tp.Transport.TLSHandshakeTimeout = 10 * time.Second
99-
prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter()
100-
prov.Router.Use(hlog.NewHandler(prov.log))
101-
prov.Router.Use(corsMiddleware)
102-
prov.Router.Use(requestlog.AccessLogger(false))
103-
prov.Router.Use(prov.AuthMiddleware)
104-
prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami)
105-
prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows)
106-
prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginStart)
107-
prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginSubmitInput)
108-
prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginWait)
109-
prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLogout)
110-
prov.Router.Path("/v3/logins").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLogins)
111-
prov.Router.Path("/v3/contacts").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetContactList)
112-
prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetResolveIdentifier)
113-
prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateDM)
114-
prov.Router.Path("/v3/create_group").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateGroup)
99+
100+
provRouter := http.NewServeMux()
101+
102+
provRouter.HandleFunc("GET /v3/whoami", prov.GetWhoami)
103+
provRouter.HandleFunc("GET /v3/whoami/flows", prov.GetLoginFlows)
104+
105+
provRouter.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart)
106+
provRouter.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}", prov.PostLogin)
107+
provRouter.HandleFunc("POST /v3/logout/{loginID}", prov.PostLogout)
108+
provRouter.HandleFunc("GET /v3/logins", prov.GetLogins)
109+
provRouter.HandleFunc("GET /v3/contacts", prov.GetContactList)
110+
provRouter.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier)
111+
provRouter.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM)
112+
provRouter.HandleFunc("POST /v3/create_group", prov.PostCreateGroup)
113+
114+
var provHandler http.Handler = prov.Router
115+
provHandler = prov.AuthMiddleware(provHandler)
116+
provHandler = requestlog.AccessLogger(false)(provHandler)
117+
provHandler = exhttp.CORSMiddleware(provHandler)
118+
provHandler = hlog.NewHandler(prov.log)(provHandler)
119+
provHandler = http.StripPrefix(prov.br.Config.Provisioning.Prefix, provHandler)
120+
prov.br.AS.Router.Handle(prov.br.Config.Provisioning.Prefix, provHandler)
115121

116122
if prov.br.Config.Provisioning.DebugEndpoints {
117123
prov.log.Debug().Msg("Enabling debug API at /debug")
118-
r := prov.br.AS.Router.PathPrefix("/debug").Subrouter()
119-
r.Use(prov.AuthMiddleware)
120-
r.PathPrefix("/pprof").Handler(http.DefaultServeMux)
124+
debugRouter := http.NewServeMux()
125+
// TODO do we need to strip prefix here?
126+
debugRouter.Handle("/debug/pprof", http.StripPrefix("/debug/pprof", http.DefaultServeMux))
127+
prov.br.AS.Router.Handle("/debug", prov.AuthMiddleware(debugRouter))
121128
}
122129
}
123130

124-
func corsMiddleware(handler http.Handler) http.Handler {
125-
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
126-
w.Header().Set("Access-Control-Allow-Origin", "*")
127-
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
128-
w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization")
129-
if r.Method == http.MethodOptions {
130-
w.WriteHeader(http.StatusOK)
131-
return
132-
}
133-
handler.ServeHTTP(w, r)
134-
})
135-
}
136-
137131
func jsonResponse(w http.ResponseWriter, status int, response any) {
138132
w.Header().Add("Content-Type", "application/json")
139133
w.WriteHeader(status)
@@ -221,7 +215,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
221215
// TODO handle user being nil?
222216

223217
ctx := context.WithValue(r.Context(), provisioningUserKey, user)
224-
if loginID, ok := mux.Vars(r)["loginProcessID"]; ok {
218+
if loginID := r.PathValue("loginProcessID"); loginID != "" {
225219
prov.loginsLock.RLock()
226220
login, ok := prov.logins[loginID]
227221
prov.loginsLock.RUnlock()
@@ -236,7 +230,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
236230
login.Lock.Lock()
237231
// This will only unlock after the handler runs
238232
defer login.Lock.Unlock()
239-
stepID := mux.Vars(r)["stepID"]
233+
stepID := r.PathValue("stepID")
240234
if login.NextStep.StepID != stepID {
241235
zerolog.Ctx(r.Context()).Warn().
242236
Str("request_step_id", stepID).
@@ -248,7 +242,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
248242
})
249243
return
250244
}
251-
stepType := mux.Vars(r)["stepType"]
245+
stepType := r.PathValue("stepType")
252246
if login.NextStep.Type != bridgev2.LoginStepType(stepType) {
253247
zerolog.Ctx(r.Context()).Warn().
254248
Str("request_step_type", stepType).
@@ -352,7 +346,7 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque
352346
login, err := prov.net.CreateLogin(
353347
r.Context(),
354348
prov.GetUser(r),
355-
mux.Vars(r)["flowID"],
349+
r.PathValue("flowID"),
356350
)
357351
if err != nil {
358352
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process")
@@ -391,6 +385,17 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov
391385
}, bridgev2.DeleteOpts{LogoutRemote: true})
392386
}
393387

388+
func (prov *ProvisioningAPI) PostLogin(w http.ResponseWriter, r *http.Request) {
389+
switch r.PathValue("stepType") {
390+
case "user_input", "cookies":
391+
prov.PostLoginSubmitInput(w, r)
392+
case "display_and_wait":
393+
prov.PostLoginWait(w, r)
394+
default:
395+
panic("Impossible state") // checked by the AuthMiddleware
396+
}
397+
}
398+
394399
func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http.Request) {
395400
var params map[string]string
396401
err := json.NewDecoder(r.Body).Decode(&params)
@@ -444,7 +449,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques
444449

445450
func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) {
446451
user := prov.GetUser(r)
447-
userLoginID := networkid.UserLoginID(mux.Vars(r)["loginID"])
452+
userLoginID := networkid.UserLoginID(r.PathValue("loginID"))
448453
if userLoginID == "all" {
449454
for {
450455
login := user.GetDefaultLogin()
@@ -548,7 +553,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.
548553
})
549554
return
550555
}
551-
resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat)
556+
resp, err := api.ResolveIdentifier(r.Context(), r.PathValue("identifier"), createChat)
552557
if err != nil {
553558
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier")
554559
respondMaybeCustomError(w, err, "Internal error resolving identifier")

bridgev2/matrix/publicmedia.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ import (
1616
"net/http"
1717
"time"
1818

19-
"github.com/gorilla/mux"
20-
2119
"maunium.net/go/mautrix/bridgev2"
2220
"maunium.net/go/mautrix/id"
2321
)
@@ -35,7 +33,7 @@ func (br *Connector) initPublicMedia() error {
3533
return fmt.Errorf("public media hash length is negative")
3634
}
3735
br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey)
38-
br.AS.Router.HandleFunc("/_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia).Methods(http.MethodGet)
36+
br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia)
3937
return nil
4038
}
4139

@@ -76,16 +74,15 @@ var proxyHeadersToCopy = []string{
7674
}
7775

7876
func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) {
79-
vars := mux.Vars(r)
8077
contentURI := id.ContentURI{
81-
Homeserver: vars["server"],
82-
FileID: vars["mediaID"],
78+
Homeserver: r.PathValue("server"),
79+
FileID: r.PathValue("mediaID"),
8380
}
8481
if !contentURI.IsValid() {
8582
http.Error(w, "invalid content URI", http.StatusBadRequest)
8683
return
8784
}
88-
checksum, err := base64.RawURLEncoding.DecodeString(vars["checksum"])
85+
checksum, err := base64.RawURLEncoding.DecodeString(r.PathValue("checksum"))
8986
if err != nil || !hmac.Equal(checksum, br.makePublicMediaChecksum(contentURI)) {
9087
http.Error(w, "invalid base64 in checksum", http.StatusBadRequest)
9188
return

bridgev2/matrixinterface.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@ package bridgev2
99
import (
1010
"context"
1111
"io"
12+
"net/http"
1213
"time"
1314

14-
"github.com/gorilla/mux"
15-
1615
"maunium.net/go/mautrix"
1716
"maunium.net/go/mautrix/bridge/status"
1817
"maunium.net/go/mautrix/bridgev2/database"
@@ -56,7 +55,7 @@ type MatrixConnector interface {
5655

5756
type MatrixConnectorWithServer interface {
5857
GetPublicAddress() string
59-
GetRouter() *mux.Router
58+
GetRouter() *http.ServeMux
6059
}
6160

6261
type MatrixConnectorWithPublicMedia interface {

0 commit comments

Comments
 (0)