@@ -16,10 +16,10 @@ import (
1616 "sync"
1717 "time"
1818
19- "github.com/gorilla/mux"
2019 "github.com/rs/xid"
2120 "github.com/rs/zerolog"
2221 "github.com/rs/zerolog/hlog"
22+ "go.mau.fi/util/exhttp"
2323 "go.mau.fi/util/jsontime"
2424 "go.mau.fi/util/requestlog"
2525
@@ -37,7 +37,7 @@ type matrixAuthCacheEntry struct {
3737}
3838
3939type ProvisioningAPI struct {
40- Router * mux. Router
40+ Router * http. ServeMux
4141
4242 br * Connector
4343 log zerolog.Logger
@@ -72,12 +72,12 @@ func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User {
7272 return r .Context ().Value (provisioningUserKey ).(* bridgev2.User )
7373}
7474
75- func (prov * ProvisioningAPI ) GetRouter () * mux. Router {
75+ func (prov * ProvisioningAPI ) GetRouter () * http. ServeMux {
7676 return prov .Router
7777}
7878
7979type IProvisioningAPI interface {
80- GetRouter () * mux. Router
80+ GetRouter () * http. ServeMux
8181 GetUser (r * http.Request ) * bridgev2.User
8282}
8383
@@ -96,44 +96,38 @@ func (prov *ProvisioningAPI) Init() {
9696 tp .Dialer .Timeout = 10 * time .Second
9797 tp .Transport .ResponseHeaderTimeout = 10 * time .Second
9898 tp .Transport .TLSHandshakeTimeout = 10 * time .Second
99- prov .Router = prov .br .AS .Router .PathPrefix (prov .br .Config .Provisioning .Prefix ).Subrouter ()
100- prov .Router .Use (hlog .NewHandler (prov .log ))
101- prov .Router .Use (corsMiddleware )
102- prov .Router .Use (requestlog .AccessLogger (false ))
103- prov .Router .Use (prov .AuthMiddleware )
104- prov .Router .Path ("/v3/whoami" ).Methods (http .MethodGet , http .MethodOptions ).HandlerFunc (prov .GetWhoami )
105- prov .Router .Path ("/v3/login/flows" ).Methods (http .MethodGet , http .MethodOptions ).HandlerFunc (prov .GetLoginFlows )
106- prov .Router .Path ("/v3/login/start/{flowID}" ).Methods (http .MethodPost , http .MethodOptions ).HandlerFunc (prov .PostLoginStart )
107- prov .Router .Path ("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}" ).Methods (http .MethodPost , http .MethodOptions ).HandlerFunc (prov .PostLoginSubmitInput )
108- prov .Router .Path ("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}" ).Methods (http .MethodPost , http .MethodOptions ).HandlerFunc (prov .PostLoginWait )
109- prov .Router .Path ("/v3/logout/{loginID}" ).Methods (http .MethodPost , http .MethodOptions ).HandlerFunc (prov .PostLogout )
110- prov .Router .Path ("/v3/logins" ).Methods (http .MethodGet , http .MethodOptions ).HandlerFunc (prov .GetLogins )
111- prov .Router .Path ("/v3/contacts" ).Methods (http .MethodGet , http .MethodOptions ).HandlerFunc (prov .GetContactList )
112- prov .Router .Path ("/v3/resolve_identifier/{identifier}" ).Methods (http .MethodGet , http .MethodOptions ).HandlerFunc (prov .GetResolveIdentifier )
113- prov .Router .Path ("/v3/create_dm/{identifier}" ).Methods (http .MethodPost , http .MethodOptions ).HandlerFunc (prov .PostCreateDM )
114- prov .Router .Path ("/v3/create_group" ).Methods (http .MethodPost , http .MethodOptions ).HandlerFunc (prov .PostCreateGroup )
99+
100+ provRouter := http .NewServeMux ()
101+
102+ provRouter .HandleFunc ("GET /v3/whoami" , prov .GetWhoami )
103+ provRouter .HandleFunc ("GET /v3/whoami/flows" , prov .GetLoginFlows )
104+
105+ provRouter .HandleFunc ("POST /v3/login/start/{flowID}" , prov .PostLoginStart )
106+ provRouter .HandleFunc ("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}" , prov .PostLogin )
107+ provRouter .HandleFunc ("POST /v3/logout/{loginID}" , prov .PostLogout )
108+ provRouter .HandleFunc ("GET /v3/logins" , prov .GetLogins )
109+ provRouter .HandleFunc ("GET /v3/contacts" , prov .GetContactList )
110+ provRouter .HandleFunc ("GET /v3/resolve_identifier/{identifier}" , prov .GetResolveIdentifier )
111+ provRouter .HandleFunc ("POST /v3/create_dm/{identifier}" , prov .PostCreateDM )
112+ provRouter .HandleFunc ("POST /v3/create_group" , prov .PostCreateGroup )
113+
114+ var provHandler http.Handler = prov .Router
115+ provHandler = prov .AuthMiddleware (provHandler )
116+ provHandler = requestlog .AccessLogger (false )(provHandler )
117+ provHandler = exhttp .CORSMiddleware (provHandler )
118+ provHandler = hlog .NewHandler (prov .log )(provHandler )
119+ provHandler = http .StripPrefix (prov .br .Config .Provisioning .Prefix , provHandler )
120+ prov .br .AS .Router .Handle (prov .br .Config .Provisioning .Prefix , provHandler )
115121
116122 if prov .br .Config .Provisioning .DebugEndpoints {
117123 prov .log .Debug ().Msg ("Enabling debug API at /debug" )
118- r := prov .br .AS .Router .PathPrefix ("/debug" ).Subrouter ()
119- r .Use (prov .AuthMiddleware )
120- r .PathPrefix ("/pprof" ).Handler (http .DefaultServeMux )
124+ debugRouter := http .NewServeMux ()
125+ // TODO do we need to strip prefix here?
126+ debugRouter .Handle ("/debug/pprof" , http .StripPrefix ("/debug/pprof" , http .DefaultServeMux ))
127+ prov .br .AS .Router .Handle ("/debug" , prov .AuthMiddleware (debugRouter ))
121128 }
122129}
123130
124- func corsMiddleware (handler http.Handler ) http.Handler {
125- return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
126- w .Header ().Set ("Access-Control-Allow-Origin" , "*" )
127- w .Header ().Set ("Access-Control-Allow-Methods" , "GET, POST, PUT, DELETE, OPTIONS" )
128- w .Header ().Set ("Access-Control-Allow-Headers" , "X-Requested-With, Content-Type, Authorization" )
129- if r .Method == http .MethodOptions {
130- w .WriteHeader (http .StatusOK )
131- return
132- }
133- handler .ServeHTTP (w , r )
134- })
135- }
136-
137131func jsonResponse (w http.ResponseWriter , status int , response any ) {
138132 w .Header ().Add ("Content-Type" , "application/json" )
139133 w .WriteHeader (status )
@@ -221,7 +215,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
221215 // TODO handle user being nil?
222216
223217 ctx := context .WithValue (r .Context (), provisioningUserKey , user )
224- if loginID , ok := mux . Vars ( r )[ "loginProcessID" ]; ok {
218+ if loginID := r . PathValue ( "loginProcessID" ); loginID != "" {
225219 prov .loginsLock .RLock ()
226220 login , ok := prov .logins [loginID ]
227221 prov .loginsLock .RUnlock ()
@@ -236,7 +230,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
236230 login .Lock .Lock ()
237231 // This will only unlock after the handler runs
238232 defer login .Lock .Unlock ()
239- stepID := mux . Vars ( r )[ "stepID" ]
233+ stepID := r . PathValue ( "stepID" )
240234 if login .NextStep .StepID != stepID {
241235 zerolog .Ctx (r .Context ()).Warn ().
242236 Str ("request_step_id" , stepID ).
@@ -248,7 +242,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
248242 })
249243 return
250244 }
251- stepType := mux . Vars ( r )[ "stepType" ]
245+ stepType := r . PathValue ( "stepType" )
252246 if login .NextStep .Type != bridgev2 .LoginStepType (stepType ) {
253247 zerolog .Ctx (r .Context ()).Warn ().
254248 Str ("request_step_type" , stepType ).
@@ -352,7 +346,7 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque
352346 login , err := prov .net .CreateLogin (
353347 r .Context (),
354348 prov .GetUser (r ),
355- mux . Vars ( r )[ "flowID" ] ,
349+ r . PathValue ( "flowID" ) ,
356350 )
357351 if err != nil {
358352 zerolog .Ctx (r .Context ()).Err (err ).Msg ("Failed to create login process" )
@@ -391,6 +385,17 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov
391385 }, bridgev2.DeleteOpts {LogoutRemote : true })
392386}
393387
388+ func (prov * ProvisioningAPI ) PostLogin (w http.ResponseWriter , r * http.Request ) {
389+ switch r .PathValue ("stepType" ) {
390+ case "user_input" , "cookies" :
391+ prov .PostLoginSubmitInput (w , r )
392+ case "display_and_wait" :
393+ prov .PostLoginWait (w , r )
394+ default :
395+ panic ("Impossible state" ) // checked by the AuthMiddleware
396+ }
397+ }
398+
394399func (prov * ProvisioningAPI ) PostLoginSubmitInput (w http.ResponseWriter , r * http.Request ) {
395400 var params map [string ]string
396401 err := json .NewDecoder (r .Body ).Decode (& params )
@@ -444,7 +449,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques
444449
445450func (prov * ProvisioningAPI ) PostLogout (w http.ResponseWriter , r * http.Request ) {
446451 user := prov .GetUser (r )
447- userLoginID := networkid .UserLoginID (mux . Vars ( r )[ "loginID" ] )
452+ userLoginID := networkid .UserLoginID (r . PathValue ( "loginID" ) )
448453 if userLoginID == "all" {
449454 for {
450455 login := user .GetDefaultLogin ()
@@ -548,7 +553,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.
548553 })
549554 return
550555 }
551- resp , err := api .ResolveIdentifier (r .Context (), mux . Vars ( r )[ "identifier" ] , createChat )
556+ resp , err := api .ResolveIdentifier (r .Context (), r . PathValue ( "identifier" ) , createChat )
552557 if err != nil {
553558 zerolog .Ctx (r .Context ()).Err (err ).Msg ("Failed to resolve identifier" )
554559 respondMaybeCustomError (w , err , "Internal error resolving identifier" )
0 commit comments