diff --git a/api/admin.go b/api/admin.go index 8022cda8f..58d47e398 100644 --- a/api/admin.go +++ b/api/admin.go @@ -68,7 +68,7 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error { return badRequestError("Bad Pagination Parameters: %v", err) } - sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{models.SortField{Name: models.CreatedAt, Dir: models.Descending}}) + sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}}) if err != nil { return badRequestError("Bad Sort Parameters: %v", err) } @@ -345,3 +345,90 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error { return sendJSON(w, http.StatusOK, map[string]interface{}{}) } + +type adminUserLabelsParams struct { + Label string `json:"label"` + State string `json:"state"` +} + +func (a *API) loadUserLabels(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + user := getUser(ctx) + + labels, err := models.FindUserLabels(a.db, user.ID) + if err != nil { + return ctx, internalServerError("Error loading user labels").WithInternalError(err) + } + + return withLabels(ctx, labels), nil +} + +func (a *API) getAdminUserLabelsParams(r *http.Request) (*adminUserLabelsParams, error) { + params := &adminUserLabelsParams{} + err := json.NewDecoder(r.Body).Decode(¶ms) + if err != nil { + return nil, badRequestError("Could not decode admin user label params: %v", err) + } + return params, nil +} + +func (a *API) adminUserLabelsCreateOrUpdate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + user := getUser(ctx) + instanceID := getInstanceID(ctx) + adminUser := getAdminUser(ctx) + existingLabels := getLabels(ctx) + config := getConfig(ctx) + params, err := a.getAdminUserLabelsParams(r) + if err != nil { + return err + } + + // check if requested label is in the list of configured labels + exists := false + for _, level := range config.UserLabels { + for _, label := range level.Labels { + if label == params.Label { + exists = true + break + } + } + } + + if !exists { + return badRequestError("Label '%s' is not defined in the config", params.Label) + } + + // perform update + err = a.db.Transaction(func(tx *storage.Connection) error { + var action models.AuditAction + + if label, ok := existingLabels[params.Label]; ok { + action = models.UserLabelModifiedAction + + if terr := label.UpdateState(tx, params.State); terr != nil { + return terr + } + } else { + action = models.UserLabelCreatedAction + newLabel := models.NewUserLabel(user.ID, params.Label, params.State) + + if terr := tx.Create(newLabel); terr != nil { + return terr + } + + existingLabels[newLabel.Label] = newLabel + } + + if terr := models.NewAuditLogEntry(tx, instanceID, adminUser, action, map[string]interface{}{ + "user_id": user.ID, + "label_name": params.Label, + "label_state": params.State, + }); terr != nil { + return internalServerError("Error recording audit log entry").WithInternalError(terr) + } + return nil + }) + + return err +} diff --git a/api/admin_test.go b/api/admin_test.go index a5de623c0..3839881e9 100644 --- a/api/admin_test.go +++ b/api/admin_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "math" "net/http" "net/http/httptest" "testing" @@ -43,7 +44,7 @@ func TestAdmin(t *testing.T) { } func (ts *AdminTestSuite) SetupTest() { - models.TruncateAll(ts.API.db) + _ = models.TruncateAll(ts.API.db) ts.Config.External.Email.Enabled = true ts.token = ts.makeSuperAdmin("") } @@ -54,7 +55,6 @@ func (ts *AdminTestSuite) makeSuperAdmin(email string) string { u.Role = "supabase_admin" - key, err := models.FindMainAsymmetricKeyByUser(ts.API.db, u) require.NoError(ts.T(), err, "Error finding keys") @@ -587,6 +587,8 @@ func (ts *AdminTestSuite) TestAdminUserCreateWithDisabledLogin() { }, } + configBackup := *ts.Config + for _, c := range cases { ts.Run(c.desc, func() { // Initialize user data @@ -603,4 +605,91 @@ func (ts *AdminTestSuite) TestAdminUserCreateWithDisabledLogin() { require.Equal(ts.T(), c.expected, w.Code) }) } + + *ts.Config = configBackup +} + +// TestAdminUserLabelsCreateOrUpdate tests API /admin/user/labels route (POST) +func (ts *AdminTestSuite) TestAdminUserLabelsCreateOrUpdate() { + cases := []struct { + desc string + params map[string]interface{} + expected map[string]interface{} + }{ + { + desc: "Create new label", + params: map[string]interface{}{ + "label": "email", + "state": "pending", + }, + expected: map[string]interface{}{ + "httpStatusCode": http.StatusOK, + "new_labels": 1, + }, + }, + { + desc: "Update existing label", + params: map[string]interface{}{ + "label": "email", + "state": "verified", + }, + expected: map[string]interface{}{ + "httpStatusCode": http.StatusOK, + "new_labels": 0, + }, + }, + { + desc: "Label does not exist", + params: map[string]interface{}{ + "label": "test", + "state": "verified", + }, + expected: map[string]interface{}{ + "httpStatusCode": http.StatusBadRequest, + "new_labels": 0, + }, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) + + u, err := models.NewUser(ts.instanceID, "test1@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + if err := ts.API.db.Create(u); err != nil { + u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test1@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err, "Error finding user") + } + + beforeLabels, err := models.FindUserLabels(ts.API.db, u.ID) + require.NoError(ts.T(), err, "Error loading user labels") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/admin/users/%v/labels", u.ID), &buffer) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + ts.Config.External.Phone.Enabled = true + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expected["httpStatusCode"], w.Code) + + // verify request results + afterLabels, err := models.FindUserLabels(ts.API.db, u.ID) + require.NoError(ts.T(), err, "Error loading user labels") + + labelsDelta := len(beforeLabels) - len(afterLabels) + require.Equal(ts.T(), int(math.Abs(float64(labelsDelta))), + c.expected["new_labels"].(int), + fmt.Sprintf("Expected %d label", c.expected["new_labels"])) + + if c.expected["new_labels"].(int) > 0 { + labelName := c.params["label"].(string) + require.Equal(ts.T(), afterLabels[labelName].Label, labelName) + require.Equal(ts.T(), afterLabels[labelName].State, c.params["state"].(string)) + } + }) + } } diff --git a/api/api.go b/api/api.go index a4169c275..2b33d09fa 100644 --- a/api/api.go +++ b/api/api.go @@ -51,7 +51,9 @@ func (a *API) ListenAndServe(hostAndPort string) { waitForTermination(log, done) ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - server.Shutdown(ctx) + if err := server.Shutdown(ctx); err != nil { + log.WithError(err).Fatal("http server shutdown failed") + } }() if err := server.ListenAndServe(); err != http.ErrServerClosed { @@ -166,6 +168,12 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati r.Get("/", api.adminUserGet) r.Put("/", api.adminUserUpdate) r.Delete("/", api.adminUserDelete) + + r.Route("/labels", func(r *router) { + r.Use(api.loadUserLabels) + + r.Post("/", api.adminUserLabelsCreateOrUpdate) + }) }) }) diff --git a/api/context.go b/api/context.go index f1c4a78c1..c822ef0f2 100644 --- a/api/context.go +++ b/api/context.go @@ -26,6 +26,7 @@ const ( netlifyIDKey = contextKey("netlify_id") externalProviderTypeKey = contextKey("external_provider_type") userKey = contextKey("user") + labelsKey = contextKey("labels") externalReferrerKey = contextKey("external_referrer") functionHooksKey = contextKey("function_hooks") adminUserKey = contextKey("admin_user") @@ -126,6 +127,20 @@ func getUser(ctx context.Context) *models.User { return obj.(*models.User) } +// withLabels adds the user labels to the context. +func withLabels(ctx context.Context, labels map[string]*models.UserLabel) context.Context { + return context.WithValue(ctx, labelsKey, labels) +} + +// getLabels reads the user labels from the context. +func getLabels(ctx context.Context) map[string]*models.UserLabel { + obj := ctx.Value(labelsKey) + if obj == nil { + return nil + } + return obj.(map[string]*models.UserLabel) +} + // withSignature adds the provided request ID to the context. func withSignature(ctx context.Context, id string) context.Context { return context.WithValue(ctx, signatureKey, id) diff --git a/api/signup.go b/api/signup.go index 7fcb0bb91..14b0c6e31 100644 --- a/api/signup.go +++ b/api/signup.go @@ -98,6 +98,11 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { if err := user.UpdateUserMetaData(tx, params.Data); err != nil { return internalServerError("Database error updating user").WithInternalError(err) } + + label := models.NewUserLabel(user.ID, params.Provider, "verified") + if terr := tx.Create(label); terr != nil { + return internalServerError("Database error creating user label").WithInternalError(terr) + } } else { user, terr = a.signupNewUser(ctx, tx, params) if terr != nil { diff --git a/conf/configuration.go b/conf/configuration.go index 72ecf2794..7aaef7746 100644 --- a/conf/configuration.go +++ b/conf/configuration.go @@ -179,6 +179,13 @@ type SecurityConfiguration struct { RefreshTokenRotationEnabled bool `json:"refresh_token_rotation_enabled" split_words:"true" default:"true"` } +type userLabel struct { + Level uint `json:"level"` + Labels []string `json:"labels"` +} + +type UserLabels []userLabel + // Configuration holds all the per-instance configuration. type Configuration struct { SiteURL string `json:"site_url" split_words:"true" required:"true"` @@ -198,6 +205,7 @@ type Configuration struct { Domain string `json:"domain"` Duration int `json:"duration"` } `json:"cookies"` + UserLabels UserLabels `json:"user_labels" split_words:"true"` } func loadEnvironment(filename string) error { @@ -343,6 +351,14 @@ func (config *Configuration) ApplyDefaults() { config.PasswordMinLength = defaultMinPasswordLength } + if len(config.UserLabels) == 0 { + config.UserLabels = UserLabels{ + {Level: 1, Labels: []string{"email", "phone"}}, + {Level: 2, Labels: []string{"profile"}}, + {Level: 3, Labels: []string{"documents"}}, + } + } + config.JWT.InitializeSigningSecret() } @@ -371,6 +387,20 @@ func (config *Configuration) Scan(src interface{}) error { return json.Unmarshal(source, &config) } +func (ul *UserLabels) Decode(value string) error { + raw, err := base64.StdEncoding.DecodeString(value) + if err != nil { + return err + } + + err = json.Unmarshal([]byte(raw), &ul) + if err != nil { + return err + } + + return nil +} + func (o *OAuthProviderConfiguration) Validate() error { if !o.Enabled { return errors.New("Provider is not enabled") diff --git a/example.env b/example.env index ec83beeac..fb37294cd 100644 --- a/example.env +++ b/example.env @@ -192,4 +192,7 @@ GOTRUE_WEBHOOK_EVENTS=validate,signup,login # Cookie config GOTRUE_COOKIE_KEY: "sb" -GOTRUE_COOKIE_DOMAIN: "localhost" \ No newline at end of file +GOTRUE_COOKIE_DOMAIN: "localhost" + +# Labels config +GOTRUE_USER_LABELS=W3sibGV2ZWwiOjEsImxhYmVscyI6WyJlbWFpbCIsICJwaG9uZSJdfSx7ImxldmVsIjoyLCJsYWJlbHMiOlsicHJvZmlsZSJdfSx7ImxldmVsIjozLCJsYWJlbHMiOlsiZG9jdW1lbnQiXX1d diff --git a/hack/test.env b/hack/test.env index 9b5d9b89c..6d9a9fb95 100644 --- a/hack/test.env +++ b/hack/test.env @@ -84,3 +84,4 @@ GOTRUE_TRACING_TAGS="env:test" GOTRUE_SECURITY_CAPTCHA_ENABLED="false" GOTRUE_SECURITY_CAPTCHA_PROVIDER="hcaptcha" GOTRUE_SECURITY_CAPTCHA_SECRET="0x0000000000000000000000000000000000000000" +GOTRUE_USER_LABELS=W3sibGV2ZWwiOjEsImxhYmVscyI6WyJlbWFpbCIsICJwaG9uZSJdfSx7ImxldmVsIjoyLCJsYWJlbHMiOlsicHJvZmlsZSJdfSx7ImxldmVsIjozLCJsYWJlbHMiOlsiZG9jdW1lbnQiXX1d diff --git a/migrations/20220720000812_create_labels_table.up.sql b/migrations/20220720000812_create_labels_table.up.sql new file mode 100644 index 000000000..ee81d7520 --- /dev/null +++ b/migrations/20220720000812_create_labels_table.up.sql @@ -0,0 +1,16 @@ +-- adds lables table + +CREATE TYPE label_name AS ENUM ('email','phone','profile','document'); +CREATE TYPE label_state AS ENUM ('unverified','pending','verified','expired'); + +CREATE TABLE IF NOT EXISTS auth.labels ( + id uuid NOT NULL, + user_id uuid NOT NULL, + label label_name NOT NULL, + state label_state NOT NULL DEFAULT 'unverified', + created_at timestamptz NOT NULL, + updated_at timestamptz NOT NULL, + CONSTRAINT labels_pkey PRIMARY KEY (id), + CONSTRAINT labels_user_id_fkey FOREIGN KEY (user_id) REFERENCES auth.users(id) ON DELETE CASCADE +); +COMMENT ON TABLE auth.labels is 'Auth: Stores labels associated to a user.'; diff --git a/models/asymmetric_key.go b/models/asymmetric_key.go index a310bcba4..1d4cea2e0 100644 --- a/models/asymmetric_key.go +++ b/models/asymmetric_key.go @@ -3,13 +3,14 @@ package models import ( "database/sql" "fmt" + "time" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/crypto" "github.com/gofrs/uuid" "github.com/netlify/gotrue/storage" "github.com/pkg/errors" - "time" ) const challengeTokenExpirationDuration = 30 * time.Minute @@ -140,7 +141,7 @@ func (a *AsymmetricKey) verifyEthKeySignature(rawSignature string) error { return nil } -// verifyKeyAndAlgorithm verifies public key format for specific algorithm. +// VerifyKeyAndAlgorithm verifies public key format for specific algorithm. // If key satisfies conditions, nil is returned func VerifyKeyAndAlgorithm(pubkey, algorithm string) error { var err error diff --git a/models/audit_log_entry.go b/models/audit_log_entry.go index 3aa39b742..500f7d1ec 100644 --- a/models/audit_log_entry.go +++ b/models/audit_log_entry.go @@ -24,6 +24,8 @@ const ( UserRecoveryRequestedAction AuditAction = "user_recovery_requested" UserConfirmationRequestedAction AuditAction = "user_confirmation_requested" UserRepeatedSignUpAction AuditAction = "user_repeated_signup" + UserLabelCreatedAction AuditAction = "user_label_created" + UserLabelModifiedAction AuditAction = "user_label_modified" TokenRevokedAction AuditAction = "token_revoked" TokenRefreshedAction AuditAction = "token_refreshed" @@ -40,12 +42,14 @@ var actionLogTypeMap = map[AuditAction]auditLogType{ UserSignedUpAction: team, UserInvitedAction: team, UserDeletedAction: team, - TokenRevokedAction: token, - TokenRefreshedAction: token, UserModifiedAction: user, UserRecoveryRequestedAction: user, UserConfirmationRequestedAction: user, UserRepeatedSignUpAction: user, + UserLabelCreatedAction: user, + UserLabelModifiedAction: user, + TokenRevokedAction: token, + TokenRefreshedAction: token, } // AuditLogEntry is the database model for audit log entries. diff --git a/models/labels.go b/models/labels.go new file mode 100644 index 000000000..00768cae1 --- /dev/null +++ b/models/labels.go @@ -0,0 +1,101 @@ +package models + +import ( + "time" + + "github.com/gobuffalo/pop/v5" + "github.com/gofrs/uuid" + "github.com/netlify/gotrue/conf" + "github.com/netlify/gotrue/storage" +) + +const ( + UserLevelKey string = "level" + configFile string = "" +) + +type UserLabel struct { + ID uuid.UUID `json:"id" db:"id"` + UserID uuid.UUID `json:"user_id" db:"user_id"` + Label string `json:"label" db:"label"` + State string `json:"state" db:"state"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +func NewUserLabel(userID uuid.UUID, label string, state string) *UserLabel { + userLabel := &UserLabel{ + UserID: userID, + Label: label, + State: state, + } + return userLabel +} + +func (UserLabel) TableName() string { + tableName := "labels" + return tableName +} + +// AfterSave is invoked afterk the user label is saved to +// the database to recalculate the user level +func (ul *UserLabel) AfterSave(tx *pop.Connection) error { + wrappedTx := &storage.Connection{Connection: tx} + + config, err := conf.LoadConfig(configFile) + if err != nil { + return err + } + + existingLabels, err := FindUserLabels(wrappedTx, ul.UserID) + if err != nil { + return err + } + + user, err := FindUserByID(wrappedTx, ul.UserID) + if err != nil { + return err + } + + newLevel := uint64(0) +levelsLoop: + for _, levelEntry := range config.UserLabels { + for _, label := range levelEntry.Labels { + if _, ok := existingLabels[label]; !ok { + break levelsLoop + } + } + newLevel++ + } + + if terr := user.UpdateUserMetaData(wrappedTx, map[string]interface{}{ + UserLevelKey: newLevel, + }); terr != nil { + return terr + } + return nil +} + +// UpdateState updates the state column of a user label +func (ul *UserLabel) UpdateState(tx *storage.Connection, state string) error { + ul.State = state + return tx.UpdateOnly(ul, "state") +} + +// FindUserLabels finds all user labels matching the provided user ID +func FindUserLabels(tx *storage.Connection, userID uuid.UUID) (map[string]*UserLabel, error) { + var labels []*UserLabel + + q := tx.Q().Where("user_id = ?", userID) + err := q.All(&labels) + if err != nil { + return nil, err + } + + res := make(map[string]*UserLabel) + for _, label := range labels { + res[label.Label] = label + } + + return res, nil +} diff --git a/models/labels_test.go b/models/labels_test.go new file mode 100644 index 000000000..4eddde28e --- /dev/null +++ b/models/labels_test.go @@ -0,0 +1,65 @@ +package models + +import ( + "testing" + + "github.com/gofrs/uuid" + "github.com/netlify/gotrue/conf" + "github.com/netlify/gotrue/storage" + "github.com/netlify/gotrue/storage/test" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type UserLabelsTestSuite struct { + suite.Suite + db *storage.Connection + Config *conf.GlobalConfiguration + user *User + label *UserLabel +} + +func TestUserLabelsTestSuite(t *testing.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + + ts := &UserLabelsTestSuite{ + db: conn, + } + defer ts.db.Close() + + suite.Run(t, ts) +} + +func (ts *UserLabelsTestSuite) SetupTest() { + err := TruncateAll(ts.db) + require.NoError(ts.T(), err, "Failed to truncate tables") + + // Create user + u, err := NewUser(uuid.Nil, "test@example.com", "secret", "test", nil) + ts.user = u + + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.db.Create(u), "Error saving new test user") + + // Create label + l := NewUserLabel(u.ID, "email", "pending") + require.NoError(ts.T(), ts.db.Create(l), "Error saving new test label") + ts.label = l +} + +func (ts *UserLabelsTestSuite) TestFindUserLabels() { + labels, err := FindUserLabels(ts.db, ts.user.ID) + require.NoError(ts.T(), err, "Error finding user labels") + require.Len(ts.T(), labels, 1, "Expected 1 user label") + require.Equal(ts.T(), "email", labels["email"].Label, "Expected user label name to match") + require.Equal(ts.T(), "pending", labels["email"].State, "Expected user label state to match") +} + +func (ts *UserLabelsTestSuite) TestUpdateState() { + err := ts.label.UpdateState(ts.db, "verified") + require.NoError(ts.T(), err, "Error updating user label state") +}