diff --git a/database.go b/database.go index 6707d46..25ca773 100644 --- a/database.go +++ b/database.go @@ -13,6 +13,25 @@ import ( type Database struct { db *mongo.Database client *mongo.Client + + options databaseOptions +} + +func (db *Database) applyOptions(options ...DatabaseOption) error { + for _, o := range options { + if err := o(&db.options); err != nil { + return err + } + } + return nil +} + +func New(options ...DatabaseOption) (*Database, error) { + db := new(Database) + if err := db.applyOptions(options...); err != nil { + return nil, err + } + return db, nil } func (db *Database) connect(options *options.ClientOptions, dbName string) error { @@ -26,9 +45,9 @@ func (db *Database) connect(options *options.ClientOptions, dbName string) error db.client = client err = db.client.Ping(context.Background(), readpref.Primary()) if err == nil { - log.Print("Connected to MongoDB!") + db.options.Logger().Print("Connected to MongoDB!") } else { - log.Panic("Could not connect to MongoDB! Please check if mongo is running.", err) + db.options.Logger().Panic("Could not connect to MongoDB! Please check if mongo is running.", err) return err } db.db = db.client.Database(dbName) @@ -42,7 +61,7 @@ func (db *Database) Connect(connectionString string, dbName string) error { } func (db *Database) Disconnect() error { - err := db.client.Disconnect(DefaultContext()); + err := db.client.Disconnect(DefaultContext()) db.db = nil return err } diff --git a/database_options.go b/database_options.go new file mode 100644 index 0000000..83d876a --- /dev/null +++ b/database_options.go @@ -0,0 +1,31 @@ +package colt + +import ( + "log" +) + +type Logger interface { + Print(v ...interface{}) + Panic(v ...interface{}) +} + +// WithLogger sets the logger for the database. If none is provided, the default logger is used +func WithLogger(logger Logger) DatabaseOption { + return func(db *databaseOptions) error { + db.logger = logger + return nil + } +} + +type DatabaseOption func(db *databaseOptions) error + +type databaseOptions struct { + logger Logger +} + +func (d *databaseOptions) Logger() Logger { + if d.logger == nil { + return log.Default() + } + return d.logger +} diff --git a/database_options_test.go b/database_options_test.go new file mode 100644 index 0000000..af3c364 --- /dev/null +++ b/database_options_test.go @@ -0,0 +1,49 @@ +package colt + +import ( + "log" + "reflect" + "testing" +) + +var ( + _ Logger = (*testLogger)(nil) +) + +type testLogger struct{} + +func (l *testLogger) Print(v ...interface{}) {} +func (l *testLogger) Panic(v ...interface{}) {} + +func TestWithLogger(t *testing.T) { + t.Run("should set the logger for the database", func(t *testing.T) { + logger := &testLogger{} + db := &Database{options: databaseOptions{}} + db.applyOptions(WithLogger(logger)) + if !reflect.DeepEqual(db.options.logger, logger) { + t.Errorf("WithLogger() = %v, want %v", db.options.logger, logger) + } + }) +} + +func Test_databaseOptions_Logger(t *testing.T) { + t.Run("should return the default logger", func(t *testing.T) { + want := log.Default() + db := &Database{options: databaseOptions{}} + + if got := db.options.Logger(); !reflect.DeepEqual(want, got) { + t.Errorf("WithLogger() = %v, want %v", got, want) + } + }) + + t.Run("should return the provided logger from the database", func(t *testing.T) { + want := &testLogger{} + db := &Database{options: databaseOptions{ + logger: want, + }} + + if got := db.options.Logger(); !reflect.DeepEqual(want, got) { + t.Errorf("WithLogger() = %v, want %v", got, want) + } + }) +}