Skip to content
89 changes: 88 additions & 1 deletion api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error {
return badRequestError("Bad Pagination Parameters: %v", err)
}

sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{models.SortField{Name: models.CreatedAt, Dir: models.Descending}})
sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}})
if err != nil {
return badRequestError("Bad Sort Parameters: %v", err)
}
Expand Down Expand Up @@ -345,3 +345,90 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {

return sendJSON(w, http.StatusOK, map[string]interface{}{})
}

type adminUserLabelsParams struct {
Label string `json:"label"`
State string `json:"state"`
}

func (a *API) loadUserLabels(w http.ResponseWriter, req *http.Request) (context.Context, error) {
ctx := req.Context()
user := getUser(ctx)

labels, err := models.FindUserLabels(a.db, user.ID)
if err != nil {
return ctx, internalServerError("Error loading user labels").WithInternalError(err)
}

return withLabels(ctx, labels), nil
}

func (a *API) getAdminUserLabelsParams(r *http.Request) (*adminUserLabelsParams, error) {
params := &adminUserLabelsParams{}
err := json.NewDecoder(r.Body).Decode(&params)
if err != nil {
return nil, badRequestError("Could not decode admin user label params: %v", err)
}
return params, nil
}

func (a *API) adminUserLabelsCreateOrUpdate(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
user := getUser(ctx)
instanceID := getInstanceID(ctx)
adminUser := getAdminUser(ctx)
existingLabels := getLabels(ctx)
config := getConfig(ctx)
params, err := a.getAdminUserLabelsParams(r)
if err != nil {
return err
}

// check if requested label is in the list of configured labels
exists := false
for _, level := range config.UserLabels {
for _, label := range level.Labels {
if label == params.Label {
exists = true
break
}
}
}

if !exists {
return badRequestError("Label '%s' is not defined in the config", params.Label)
}

// perform update
err = a.db.Transaction(func(tx *storage.Connection) error {
var action models.AuditAction

if label, ok := existingLabels[params.Label]; ok {
action = models.UserLabelModifiedAction

if terr := label.UpdateState(tx, params.State); terr != nil {
return terr
}
} else {
action = models.UserLabelCreatedAction
newLabel := models.NewUserLabel(user.ID, params.Label, params.State)

if terr := tx.Create(newLabel); terr != nil {
return terr
}

existingLabels[newLabel.Label] = newLabel
}

if terr := models.NewAuditLogEntry(tx, instanceID, adminUser, action, map[string]interface{}{
"user_id": user.ID,
"label_name": params.Label,
"label_state": params.State,
}); terr != nil {
return internalServerError("Error recording audit log entry").WithInternalError(terr)
}
return nil
})

return err
}
93 changes: 91 additions & 2 deletions api/admin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"math"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -43,7 +44,7 @@ func TestAdmin(t *testing.T) {
}

func (ts *AdminTestSuite) SetupTest() {
models.TruncateAll(ts.API.db)
_ = models.TruncateAll(ts.API.db)
ts.Config.External.Email.Enabled = true
ts.token = ts.makeSuperAdmin("")
}
Expand All @@ -54,7 +55,6 @@ func (ts *AdminTestSuite) makeSuperAdmin(email string) string {

u.Role = "supabase_admin"


key, err := models.FindMainAsymmetricKeyByUser(ts.API.db, u)
require.NoError(ts.T(), err, "Error finding keys")

Expand Down Expand Up @@ -587,6 +587,8 @@ func (ts *AdminTestSuite) TestAdminUserCreateWithDisabledLogin() {
},
}

configBackup := *ts.Config

for _, c := range cases {
ts.Run(c.desc, func() {
// Initialize user data
Expand All @@ -603,4 +605,91 @@ func (ts *AdminTestSuite) TestAdminUserCreateWithDisabledLogin() {
require.Equal(ts.T(), c.expected, w.Code)
})
}

*ts.Config = configBackup
}

// TestAdminUserLabelsCreateOrUpdate tests API /admin/user/labels route (POST)
func (ts *AdminTestSuite) TestAdminUserLabelsCreateOrUpdate() {
cases := []struct {
desc string
params map[string]interface{}
expected map[string]interface{}
}{
{
desc: "Create new label",
params: map[string]interface{}{
"label": "email",
"state": "pending",
},
expected: map[string]interface{}{
"httpStatusCode": http.StatusOK,
"new_labels": 1,
},
},
{
desc: "Update existing label",
params: map[string]interface{}{
"label": "email",
"state": "verified",
},
expected: map[string]interface{}{
"httpStatusCode": http.StatusOK,
"new_labels": 0,
},
},
{
desc: "Label does not exist",
params: map[string]interface{}{
"label": "test",
"state": "verified",
},
expected: map[string]interface{}{
"httpStatusCode": http.StatusBadRequest,
"new_labels": 0,
},
},
}

for _, c := range cases {
ts.Run(c.desc, func() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params))

u, err := models.NewUser(ts.instanceID, "[email protected]", "test", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error making new user")
if err := ts.API.db.Create(u); err != nil {
u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "[email protected]", ts.Config.JWT.Aud)
require.NoError(ts.T(), err, "Error finding user")
}

beforeLabels, err := models.FindUserLabels(ts.API.db, u.ID)
require.NoError(ts.T(), err, "Error loading user labels")

// Setup request
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/admin/users/%v/labels", u.ID), &buffer)

req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))
ts.Config.External.Phone.Enabled = true

ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), c.expected["httpStatusCode"], w.Code)

// verify request results
afterLabels, err := models.FindUserLabels(ts.API.db, u.ID)
require.NoError(ts.T(), err, "Error loading user labels")

labelsDelta := len(beforeLabels) - len(afterLabels)
require.Equal(ts.T(), int(math.Abs(float64(labelsDelta))),
c.expected["new_labels"].(int),
fmt.Sprintf("Expected %d label", c.expected["new_labels"]))

if c.expected["new_labels"].(int) > 0 {
labelName := c.params["label"].(string)
require.Equal(ts.T(), afterLabels[labelName].Label, labelName)
require.Equal(ts.T(), afterLabels[labelName].State, c.params["state"].(string))
}
})
}
}
10 changes: 9 additions & 1 deletion api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ func (a *API) ListenAndServe(hostAndPort string) {
waitForTermination(log, done)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
server.Shutdown(ctx)
if err := server.Shutdown(ctx); err != nil {
log.WithError(err).Fatal("http server shutdown failed")
}
}()

if err := server.ListenAndServe(); err != http.ErrServerClosed {
Expand Down Expand Up @@ -166,6 +168,12 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
r.Get("/", api.adminUserGet)
r.Put("/", api.adminUserUpdate)
r.Delete("/", api.adminUserDelete)

r.Route("/labels", func(r *router) {
r.Use(api.loadUserLabels)

r.Post("/", api.adminUserLabelsCreateOrUpdate)
})
})
})

Expand Down
15 changes: 15 additions & 0 deletions api/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const (
netlifyIDKey = contextKey("netlify_id")
externalProviderTypeKey = contextKey("external_provider_type")
userKey = contextKey("user")
labelsKey = contextKey("labels")
externalReferrerKey = contextKey("external_referrer")
functionHooksKey = contextKey("function_hooks")
adminUserKey = contextKey("admin_user")
Expand Down Expand Up @@ -126,6 +127,20 @@ func getUser(ctx context.Context) *models.User {
return obj.(*models.User)
}

// withLabels adds the user labels to the context.
func withLabels(ctx context.Context, labels map[string]*models.UserLabel) context.Context {
return context.WithValue(ctx, labelsKey, labels)
}

// getLabels reads the user labels from the context.
func getLabels(ctx context.Context) map[string]*models.UserLabel {
obj := ctx.Value(labelsKey)
if obj == nil {
return nil
}
return obj.(map[string]*models.UserLabel)
}

// withSignature adds the provided request ID to the context.
func withSignature(ctx context.Context, id string) context.Context {
return context.WithValue(ctx, signatureKey, id)
Expand Down
5 changes: 5 additions & 0 deletions api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
if err := user.UpdateUserMetaData(tx, params.Data); err != nil {
return internalServerError("Database error updating user").WithInternalError(err)
}

label := models.NewUserLabel(user.ID, params.Provider, "verified")
if terr := tx.Create(label); terr != nil {
return internalServerError("Database error creating user label").WithInternalError(terr)
}
} else {
user, terr = a.signupNewUser(ctx, tx, params)
if terr != nil {
Expand Down
30 changes: 30 additions & 0 deletions conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,13 @@ type SecurityConfiguration struct {
RefreshTokenRotationEnabled bool `json:"refresh_token_rotation_enabled" split_words:"true" default:"true"`
}

type userLabel struct {
Level uint `json:"level"`
Labels []string `json:"labels"`
}

type UserLabels []userLabel

// Configuration holds all the per-instance configuration.
type Configuration struct {
SiteURL string `json:"site_url" split_words:"true" required:"true"`
Expand All @@ -198,6 +205,7 @@ type Configuration struct {
Domain string `json:"domain"`
Duration int `json:"duration"`
} `json:"cookies"`
UserLabels UserLabels `json:"user_labels" split_words:"true"`
}

func loadEnvironment(filename string) error {
Expand Down Expand Up @@ -343,6 +351,14 @@ func (config *Configuration) ApplyDefaults() {
config.PasswordMinLength = defaultMinPasswordLength
}

if len(config.UserLabels) == 0 {
config.UserLabels = UserLabels{
{Level: 1, Labels: []string{"email", "phone"}},
{Level: 2, Labels: []string{"profile"}},
{Level: 3, Labels: []string{"documents"}},
}
}

config.JWT.InitializeSigningSecret()
}

Expand Down Expand Up @@ -371,6 +387,20 @@ func (config *Configuration) Scan(src interface{}) error {
return json.Unmarshal(source, &config)
}

func (ul *UserLabels) Decode(value string) error {
raw, err := base64.StdEncoding.DecodeString(value)
if err != nil {
return err
}

err = json.Unmarshal([]byte(raw), &ul)
if err != nil {
return err
}

return nil
}

func (o *OAuthProviderConfiguration) Validate() error {
if !o.Enabled {
return errors.New("Provider is not enabled")
Expand Down
5 changes: 4 additions & 1 deletion example.env
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,7 @@ GOTRUE_WEBHOOK_EVENTS=validate,signup,login

# Cookie config
GOTRUE_COOKIE_KEY: "sb"
GOTRUE_COOKIE_DOMAIN: "localhost"
GOTRUE_COOKIE_DOMAIN: "localhost"

# Labels config
GOTRUE_USER_LABELS=W3sibGV2ZWwiOjEsImxhYmVscyI6WyJlbWFpbCIsICJwaG9uZSJdfSx7ImxldmVsIjoyLCJsYWJlbHMiOlsicHJvZmlsZSJdfSx7ImxldmVsIjozLCJsYWJlbHMiOlsiZG9jdW1lbnQiXX1d
1 change: 1 addition & 0 deletions hack/test.env
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,4 @@ GOTRUE_TRACING_TAGS="env:test"
GOTRUE_SECURITY_CAPTCHA_ENABLED="false"
GOTRUE_SECURITY_CAPTCHA_PROVIDER="hcaptcha"
GOTRUE_SECURITY_CAPTCHA_SECRET="0x0000000000000000000000000000000000000000"
GOTRUE_USER_LABELS=W3sibGV2ZWwiOjEsImxhYmVscyI6WyJlbWFpbCIsICJwaG9uZSJdfSx7ImxldmVsIjoyLCJsYWJlbHMiOlsicHJvZmlsZSJdfSx7ImxldmVsIjozLCJsYWJlbHMiOlsiZG9jdW1lbnQiXX1d
16 changes: 16 additions & 0 deletions migrations/20220720000812_create_labels_table.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
-- adds lables table

CREATE TYPE label_name AS ENUM ('email','phone','profile','document');
CREATE TYPE label_state AS ENUM ('unverified','pending','verified','expired');

CREATE TABLE IF NOT EXISTS auth.labels (
id uuid NOT NULL,
user_id uuid NOT NULL,
label label_name NOT NULL,
state label_state NOT NULL DEFAULT 'unverified',
created_at timestamptz NOT NULL,
updated_at timestamptz NOT NULL,
CONSTRAINT labels_pkey PRIMARY KEY (id),
CONSTRAINT labels_user_id_fkey FOREIGN KEY (user_id) REFERENCES auth.users(id) ON DELETE CASCADE
);
COMMENT ON TABLE auth.labels is 'Auth: Stores labels associated to a user.';
Loading