diff --git a/doorway/_build/.dockerignore b/doorway/_build/.dockerignore new file mode 100644 index 0000000..b273849 --- /dev/null +++ b/doorway/_build/.dockerignore @@ -0,0 +1,2 @@ +.gitignore +doorway.tgz diff --git a/doorway/_build/.gitignore b/doorway/_build/.gitignore new file mode 100644 index 0000000..8a8f34a --- /dev/null +++ b/doorway/_build/.gitignore @@ -0,0 +1,2 @@ +/doorway +/doorway.tgz diff --git a/doorway/_build/Dockerfile b/doorway/_build/Dockerfile new file mode 100644 index 0000000..d011a43 --- /dev/null +++ b/doorway/_build/Dockerfile @@ -0,0 +1,27 @@ +# syntax=docker/dockerfile:1.3-labs + +FROM ubuntu:24.04 + +RUN < + + + + {{.Title}} + + + \ No newline at end of file diff --git a/reefd/auth_gate.go b/reefd/auth_gate.go new file mode 100644 index 0000000..1c48407 --- /dev/null +++ b/reefd/auth_gate.go @@ -0,0 +1,201 @@ +package reefd + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "golang.org/x/crypto/ssh" + + "github.com/ray-project/rayci/reefd/reefapi" +) + +type authGate struct { + sessionStore *sessionStore + + rand io.Reader + + nowFunc func() time.Time + userKeys map[string]string + + unauth map[string]struct{} +} + +func newAuthGate( + sessions *sessionStore, userKeys map[string]string, + unauth []string, +) *authGate { + unauthMap := make(map[string]struct{}, len(unauth)) + for _, u := range unauth { + unauthMap[u] = struct{}{} + } + + return &authGate{ + sessionStore: sessions, + + rand: rand.Reader, + nowFunc: time.Now, + userKeys: userKeys, + unauth: unauthMap, + } +} + +const sessionTokenPrefix = "ses_" + +func (g *authGate) newSessionToken( + ctx context.Context, req *reefapi.TokenRequest, +) (string, error) { + const tokenSize = 24 + rand := make([]byte, tokenSize) + if _, err := io.ReadFull(g.rand, rand); err != nil { + return "", fmt.Errorf("generate random token: %w", err) + } + + // Encode the token in base64 to make it URL safe. + token := sessionTokenPrefix + base64.RawURLEncoding.EncodeToString(rand) + now := g.nowFunc() + const ttl = 10 * time.Hour + + session := &session{ + user: req.User, + token: token, + expire: now.Add(ttl), + } + + if err := g.sessionStore.insert(ctx, session); err != nil { + return "", fmt.Errorf("save session: %w", err) + } + + return token, nil +} + +func (g *authGate) apiLogin(ctx context.Context, req *reefapi.LoginRequest) ( + *reefapi.LoginResponse, error, +) { + if req.User == "" { + return nil, fmt.Errorf("user is empty") + } + + keyBytes, ok := g.userKeys[req.User] + if !ok { + return nil, fmt.Errorf("user %q not found", req.User) + } + + pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyBytes)) + if err != nil { + return nil, fmt.Errorf( + "parse public key of user %q: %w", req.User, err, + ) + } + + if req.SigningMethod != "ssh-ed25519" { + return nil, fmt.Errorf( + "unsupported signing method %q", + req.SigningMethod, + ) + } + + sig := new(reefapi.SSHSignature) + if err := json.Unmarshal(req.Signature, sig); err != nil { + return nil, fmt.Errorf("unmarshal signature: %w", err) + } + + sshSig := &ssh.Signature{ + Format: sig.Format, + Blob: sig.Blob, + Rest: sig.Rest, + } + + if err := pubKey.Verify(req.TokenRequest, sshSig); err != nil { + return nil, fmt.Errorf("verify signature: %w", err) + } + + tokenReq := new(reefapi.TokenRequest) + if err := json.Unmarshal(req.TokenRequest, tokenReq); err != nil { + return nil, fmt.Errorf("unmarshal token request: %w", err) + } + + if tokenReq.User != req.User { + return nil, fmt.Errorf( + "user mismatch: %q != %q", + tokenReq.User, req.User, + ) + } + + sessionToken, err := g.newSessionToken(ctx, tokenReq) + if err != nil { + return nil, fmt.Errorf("new session token: %w", err) + } + + resp := &reefapi.LoginResponse{SessionToken: sessionToken} + return resp, nil +} + +func (g *authGate) check(ctx context.Context, token string) (string, error) { + ses, err := g.sessionStore.get(ctx, token) + if err != nil { + return "", fmt.Errorf("get session: %w", err) + } + if g.nowFunc().After(ses.expire) { + return "", fmt.Errorf("session %q has expired", token) + } + return ses.user, nil +} + +func (g *authGate) gate(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, ok := g.unauth[r.URL.Path]; ok { + // unauth endpoints are not protected. + h.ServeHTTP(w, r) + return + } + + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + http.Error( + w, "authorization header is empty", + http.StatusUnauthorized, + ) + return + } + + const bearerPrefix = "Bearer " + + if !strings.HasPrefix(authHeader, bearerPrefix) { + http.Error( + w, "authorization header must be Bearer", + http.StatusUnauthorized, + ) + return + } + + ctx := r.Context() + token := strings.TrimPrefix(authHeader, bearerPrefix) + user, err := g.check(ctx, token) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + ctx = contextWithUser(ctx, user) + h.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func (g *authGate) apiLogout(ctx context.Context, req *reefapi.LogoutRequest) ( + *reefapi.LogoutResponse, error, +) { + if req.SessionToken == "" { + return nil, fmt.Errorf("session token is empty") + } + if err := g.sessionStore.delete(ctx, req.SessionToken); err != nil { + return nil, fmt.Errorf("delete session: %w", err) + } + return &reefapi.LogoutResponse{}, nil +} diff --git a/reefd/build.sh b/reefd/build.sh index 96009c9..a50ba3b 100644 --- a/reefd/build.sh +++ b/reefd/build.sh @@ -1,6 +1,7 @@ #!/bin/bash set -euo pipefail +set -x CGO_ENABLED=0 go build -trimpath -o _build/reefd ./reefd ( diff --git a/reefd/build_dev.sh b/reefd/build_dev.sh new file mode 100644 index 0000000..88ec3c1 --- /dev/null +++ b/reefd/build_dev.sh @@ -0,0 +1,7 @@ +#!/bin/bassh + +set -euo pipefail +set -ex + +CGO_ENABLED=0 go build -o _dev/reefd ./reefd +CGO_ENABLED=0 go build -o _dev/reefy ./reefy diff --git a/reefd/context.go b/reefd/context.go new file mode 100644 index 0000000..75ff3d1 --- /dev/null +++ b/reefd/context.go @@ -0,0 +1,18 @@ +package reefd + +import "context" + +type contextKey int // userKey is the key used to store the user in the context. + +const ( + userKey contextKey = iota +) + +func contextWithUser(ctx context.Context, user string) context.Context { + return context.WithValue(ctx, userKey, user) +} + +func userFromContext(ctx context.Context) (string, bool) { + user, ok := ctx.Value(userKey).(string) + return user, ok +} diff --git a/reefd/context_test.go b/reefd/context_test.go new file mode 100644 index 0000000..fc81a91 --- /dev/null +++ b/reefd/context_test.go @@ -0,0 +1,26 @@ +package reefd + +import ( + "context" + "testing" +) + +func TestContextWithUser(t *testing.T) { + ctx := context.Background() + empty, ok := userFromContext(ctx) + if ok { + t.Errorf("got ok %v, want false", ok) + } + if empty != "" { + t.Errorf("got user %q, want empty", empty) + } + + ctx = contextWithUser(ctx, "testuser") + got, ok := userFromContext(ctx) + if !ok { + t.Errorf("got ok %v, want true", ok) + } + if got != "testuser" { + t.Errorf("got user %q, want %q", got, "testuser") + } +} diff --git a/reefd/reaper.go b/reefd/reaper.go index 1d91040..b2390d3 100644 --- a/reefd/reaper.go +++ b/reefd/reaper.go @@ -8,6 +8,7 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" ) @@ -100,3 +101,20 @@ func (r *reaper) listAndReapDeadWindowsInstances(ctx context.Context) ( return len(ids), nil } + +// ReapDeadWindowsInstances lists and terminates dead Windows CI instances. +func ReapDeadWindowsInstances(ctx context.Context) error { + awsConfig, err := awsconfig.LoadDefaultConfig( + ctx, awsconfig.WithRegion(awsRegion), + ) + if err != nil { + return fmt.Errorf("load aws config: %w", err) + } + + clients := newAWSClientsFromConfig(&awsConfig) + r := newReaper(clients.ec2()) + if _, err := r.listAndReapDeadWindowsInstances(ctx); err != nil { + return err + } + return nil +} diff --git a/reefd/reefapi/login.go b/reefd/reefapi/login.go new file mode 100644 index 0000000..4bce1b5 --- /dev/null +++ b/reefd/reefapi/login.go @@ -0,0 +1,46 @@ +package reefapi + +// TokenRequest is the encoded request to sign in with a user name. +type TokenRequest struct { + // User is the user name to sign in with. + User string `json:"user"` +} + +// SSHSignature is the signature structure of the verifying an token request +// that is signed with an SSH key. +type SSHSignature struct { + Format string `json:"format"` + Blob []byte `json:"blob"` + Rest []byte `json:"rest,omitempty"` +} + +// LoginRequest is the request to sign in as an user. +type LoginRequest struct { + // User is the user name to sign in with. + // It hints on which user and key to use to verify the token request. + User string `json:"user"` + + // TokenRequest is the JSON encoded token request. + TokenRequest []byte `json:"token_request"` + + // SigningMethod is the method used to sign the token request. + SigningMethod string `json:"signing_method"` + + // Signature is the cryptographic signature of the token request. + Signature []byte `json:"signature"` +} + +// LoginResponse is the response to a successful login request. +// It contains the session token to use for subsequent requests. +type LoginResponse struct { + SessionToken string `json:"session_token"` +} + +// LogoutRequest is the request to log out of a session. +type LogoutRequest struct { + // SessionToken is the session token to log out. + SessionToken string `json:"session_token"` +} + +// LogoutResponse is the response to a successful logout request. +type LogoutResponse struct{} diff --git a/reefd/reefclient/login.go b/reefd/reefclient/login.go new file mode 100644 index 0000000..140053d --- /dev/null +++ b/reefd/reefclient/login.go @@ -0,0 +1,124 @@ +package reefclient + +import ( + "context" + "crypto/rand" + "encoding/json" + "fmt" + "log" + "os" + + "github.com/ray-project/rayci/reefd/reefapi" + "golang.org/x/crypto/ssh" +) + +type client struct { + caller *JSONCaller +} + +func newClient(server string) (*client, error) { + caller, err := NewJSONCaller(server) + if err != nil { + return nil, fmt.Errorf("new caller: %w", err) + } + return &client{caller: caller}, nil +} + +func (c *client) callLogin(ctx context.Context, req *reefapi.LoginRequest) ( + *reefapi.LoginResponse, error, +) { + resp := &reefapi.LoginResponse{} + if err := JSONCall(ctx, c.caller, "api/v1/login", req, resp); err != nil { + return nil, err + } + return resp, nil +} + +func (c *client) callLogout(ctx context.Context, req *reefapi.LogoutRequest) ( + *reefapi.LogoutResponse, error, +) { + resp := &reefapi.LogoutResponse{} + if err := JSONCall(ctx, c.caller, "api/v1/logout", req, resp); err != nil { + return nil, err + } + return resp, nil +} + +func (c *client) login(ctx context.Context, user string) (string, error) { + tokenReq := &reefapi.TokenRequest{User: user} + + tokenReqBytes, err := json.Marshal(tokenReq) + if err != nil { + return "", fmt.Errorf("encode token request: %w", err) + } + privateKeyFile := os.ExpandEnv("$HOME/.ssh/id_ed25519") + privateKeyBytes, err := os.ReadFile(privateKeyFile) + if err != nil { + return "", fmt.Errorf("read private key: %w", err) + } + + priKey, err := ssh.ParsePrivateKey(privateKeyBytes) + if err != nil { + return "", fmt.Errorf("parse private key: %w", err) + } + + sshSig, err := priKey.Sign(rand.Reader, tokenReqBytes) + if err != nil { + return "", fmt.Errorf("sign token request: %w", err) + } + sigBytes, err := json.Marshal(&reefapi.SSHSignature{ + Format: sshSig.Format, + Blob: sshSig.Blob, + Rest: sshSig.Rest, + }) + if err != nil { + return "", fmt.Errorf("encode signature: %w", err) + } + + resp, err := c.callLogin(ctx, &reefapi.LoginRequest{ + User: user, + TokenRequest: tokenReqBytes, + SigningMethod: "ssh-ed25519", + Signature: sigBytes, + }) + if err != nil { + return "", fmt.Errorf("login: %w", err) + } + + return resp.SessionToken, nil +} + +func (c *client) logout(ctx context.Context, sessionToken string) error { + logoutReq := &reefapi.LogoutRequest{ + SessionToken: sessionToken, + } + if _, err := c.callLogout(ctx, logoutReq); err != nil { + return fmt.Errorf("logout: %w", err) + } + return nil +} + +// Main is the main function that runs the client. +func Main(ctx context.Context) error { + const server = "http://localhost:8000" + + client, err := newClient(server) + if err != nil { + return fmt.Errorf("new client: %w", err) + } + + const user = "aslonnie" + + tok, err := client.login(ctx, user) + if err != nil { + return fmt.Errorf("login: %w", err) + } + log.Printf("session token: %q", tok) + + if err := client.logout(ctx, tok); err != nil { + return fmt.Errorf("logout: %w", err) + } + log.Println("successfully logout") + + return nil +} diff --git a/reefd/reefd/main.go b/reefd/reefd/main.go index 1180593..4d2b397 100644 --- a/reefd/reefd/main.go +++ b/reefd/reefd/main.go @@ -1,19 +1,70 @@ package main import ( + "context" + "encoding/json" "flag" + "fmt" "log" + "os" "github.com/ray-project/rayci/reefd" + "github.com/tailscale/hujson" ) -func main() { +func readConfig(path string) (*reefd.Config, error) { + bs, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read config file: %w", err) + } + bs, err = hujson.Standardize(bs) + if err != nil { + return nil, fmt.Errorf("standardize config file: %w", err) + } + config := &reefd.Config{} + if err := json.Unmarshal(bs, config); err != nil { + return nil, fmt.Errorf("unmarshal config file: %w", err) + } + return config, nil +} + +func main() { addr := flag.String("addr", "localhost:8000", "address to listen on") + configFile := flag.String( + "config", "config.hujson", "path to the config file", + ) flag.Parse() - log.Printf("serving at %s", *addr) - if err := reefd.Serve(*addr, config); err != nil { - log.Fatal(err) + args := flag.Args() + if len(args) > 1 { + log.Fatal("Usage: reefd takes 0 or 1 argument") + } + + config, err := readConfig(*configFile) + if err != nil { + log.Fatalf("read config file: %v", err) + } + + if len(args) == 0 { + // serving by default. + log.Printf("serving at %s", *addr) + ctx := context.Background() + if err := reefd.Serve(ctx, *addr, config); err != nil { + log.Fatal(err) + } + return + } + + cmd := args[0] + switch cmd { + case "reap-windows": + log.Printf("reaping windows") + ctx := context.Background() + if err := reefd.ReapDeadWindowsInstances(ctx); err != nil { + log.Fatal(err) + } + default: + log.Fatalf("unknown command %q", cmd) } } diff --git a/reefd/reefy/main.go b/reefd/reefy/main.go new file mode 100644 index 0000000..f49cfe6 --- /dev/null +++ b/reefd/reefy/main.go @@ -0,0 +1,14 @@ +package main + +import ( + "context" + "log" + + "github.com/ray-project/rayci/reefd/reefclient" +) + +func main() { + if err := reefclient.Main(context.Background()); err != nil { + log.Fatal(err) + } +} diff --git a/reefd/serve.go b/reefd/serve.go index 30c27c4..026fef8 100644 --- a/reefd/serve.go +++ b/reefd/serve.go @@ -2,32 +2,138 @@ package reefd import ( + "context" + "fmt" "io" + "log" "net/http" + "strings" + "time" ) // Config contains the configuration for the running the server. type Config struct { + Database string + + DisableBackground bool + + UserKeys map[string]string } type server struct { + db *database + + stores []store + + reaper *reaper config *Config + + apiV1 http.Handler } -func newServer(c *Config) *server { - return &server{config: c} +func newServer(ctx context.Context, config *Config) (*server, error) { + db, err := newSqliteDB(config.Database) + if err != nil { + return nil, fmt.Errorf("new sqlite db: %w", err) + } + + awsClients, err := newAWSClients(ctx) + if err != nil { + return nil, fmt.Errorf("new aws clients: %w", err) + } + reaper := newReaper(awsClients.ec2()) + + sessionStore := newSessionStore(db) + authGate := newAuthGate( + sessionStore, + config.UserKeys, + []string{ + "/api/v1/login", + "/api/v1/logout", + }, // unauthenticated endpoints + ) + + apiMux := http.NewServeMux() + apiMux.Handle("/api/v1/login", jsonAPI(authGate.apiLogin)) + apiMux.Handle("/api/v1/logout", jsonAPI(authGate.apiLogout)) + + stores := []store{sessionStore} + + return &server{ + db: db, + stores: stores, + reaper: reaper, + config: config, + apiV1: authGate.gate(apiMux), + }, nil +} + +func (s *server) initStorage(ctx context.Context) error { + return createAll(ctx, s.stores) } func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, "Hello, World!") + if strings.HasPrefix(r.URL.Path, "/api/v1/") { + s.apiV1.ServeHTTP(w, r) + return + } + + io.WriteString(w, "Ray CI") +} + +func (s *server) listAndReapDeadWindowsInstances(ctx context.Context) error { + for { + n, err := s.reaper.listAndReapDeadWindowsInstances(ctx) + if err != nil { + return err + } + if n == 0 { + return nil + } + + time.Sleep(5 * time.Second) + } } +func (s *server) background(ctx context.Context) { + log.Println("background process started") + + if err := s.listAndReapDeadWindowsInstances(ctx); err != nil { + log.Println("listAndReapDeadWindowsInstances: ", err) + } + + const period = 20 * time.Minute + ticker := time.NewTicker(period) + defer ticker.Stop() + + for range ticker.C { + if err := s.listAndReapDeadWindowsInstances(ctx); err != nil { + log.Println("listAndReapDeadWindowsInstances: ", err) + } + } +} + +func (s *server) Close() error { return nil } + // Serve runs the server. -func Serve(addr string, c *Config) error { - s := newServer(c) +func Serve(ctx context.Context, addr string, config *Config) error { + s, err := newServer(ctx, config) + if err != nil { + return fmt.Errorf("new server: %w", err) + } + if err := s.initStorage(ctx); err != nil { + return fmt.Errorf("init storage: %w", err) + } + httpServer := &http.Server{ Addr: addr, Handler: s, } + + if !config.DisableBackground { + go s.background(ctx) + } + + defer s.Close() return httpServer.ListenAndServe() } diff --git a/reefd/serve_test.go b/reefd/serve_test.go index f44559e..82b218a 100644 --- a/reefd/serve_test.go +++ b/reefd/serve_test.go @@ -1,6 +1,7 @@ package reefd import ( + "context" "io" "net/http" "net/http/httptest" @@ -8,12 +9,19 @@ import ( ) func TestServer(t *testing.T) { - s := httptest.NewServer(newServer(&Config{})) - defer s.Close() + ctx := context.Background() + s, err := newServer(ctx, &Config{}) + if err != nil { + t.Fatal("new server: ", err) + } + + httpServer := httptest.NewServer(s) + defer httpServer.Close() - resp, err := s.Client().Get(s.URL) + client := httpServer.Client() + resp, err := client.Get(httpServer.URL) if err != nil { - t.Fatal(err) + t.Fatal("get url: ", err) } defer resp.Body.Close() @@ -21,7 +29,7 @@ func TestServer(t *testing.T) { t.Errorf("got status code %d, want %d", resp.StatusCode, http.StatusOK) } - want := "Hello, World!" + want := "Ray CI" got, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("read response: %v", err) diff --git a/reefd/session_store.go b/reefd/session_store.go new file mode 100644 index 0000000..3312808 --- /dev/null +++ b/reefd/session_store.go @@ -0,0 +1,81 @@ +package reefd + +import ( + "context" + "database/sql" + "errors" + "time" +) + +type session struct { + user string + token string + expire time.Time +} + +type sessionStore struct { + db *database +} + +func newSessionStore(db *database) *sessionStore { + return &sessionStore{db: db} +} + +// create the session tables if it doesn't exist. +func (s *sessionStore) create(ctx context.Context) error { + _, err := s.db.X(ctx, ` + CREATE TABLE IF NOT EXISTS sessions ( + token text NOT NULL PRIMARY KEY, + user text NOT NULL, + expire integer NOT NULL)`, + ) + return err +} + +// destroy the session tables. +func (s *sessionStore) destroy(ctx context.Context) error { + _, err := s.db.X(ctx, `drop table if exists sessions`) + return err +} + +// insert a session into the database. +func (s *sessionStore) insert( + ctx context.Context, session *session, +) error { + _, err := s.db.X( + ctx, `INSERT INTO sessions (token, user, expire) VALUES (?, ?, ?)`, + session.token, session.user, session.expire.Unix(), + ) + return err +} + +var errSessionNotFound = errors.New("session not found") + +// get a session from the database by token. +func (s *sessionStore) get(ctx context.Context, token string) (*session, error) { + var user string + var expire int64 + + if err := s.db.Q1( + ctx, `SELECT user, expire FROM sessions WHERE token = ?`, token, + ).Scan(&user, &expire); err != nil { + if err == sql.ErrNoRows { + return nil, errSessionNotFound + } + return nil, err + } + + return &session{user: user, token: token, expire: time.Unix(expire, 0)}, nil +} + +// delete a session from the database by token. +func (s *sessionStore) delete(ctx context.Context, token string) error { + _, err := s.db.X(ctx, `DELETE FROM sessions WHERE token = ?`, token) + return err +} + +// delete expired sessions from the database. +func (s *sessionStore) deleteExpired(ctx context.Context, t time.Time) error { + _, err := s.db.X(ctx, `DELETE FROM sessions WHERE expire < ?`, t.Unix()) + return err +} diff --git a/reefd/session_store_test.go b/reefd/session_store_test.go new file mode 100644 index 0000000..fffe182 --- /dev/null +++ b/reefd/session_store_test.go @@ -0,0 +1,130 @@ +package reefd + +import ( + "errors" + "testing" + "time" + + "context" +) + +func TestSessionStore_lifecycle(t *testing.T) { + ctx := context.Background() + db, err := newSqliteDB("") + if err != nil { + t.Fatalf("new sqlite db: %v", err) + } + + s := newSessionStore(db) + + if err := s.create(ctx); err != nil { + t.Fatalf("create: %v", err) + } + if err := s.destroy(ctx); err != nil { + t.Fatalf("destroy: %v", err) + } +} + +func TestSessionStore_insert(t *testing.T) { + ctx := context.Background() + db, err := newSqliteDB("") + if err != nil { + t.Fatalf("new sqlite db: %v", err) + } + + s := newSessionStore(db) + if err := s.create(ctx); err != nil { + t.Fatalf("create: %v", err) + } + + now := time.Now().Truncate(time.Second) + session := &session{ + user: "testuser", + token: "testtoken", + expire: now.Add(time.Hour * 24), + } + + if err := s.insert(ctx, session); err != nil { + t.Fatalf("insert: %v", err) + } + + got, err := s.get(ctx, session.token) + if err != nil { + t.Fatalf("get session: %v", err) + } + + if got.user != session.user { + t.Errorf("got user %q, want %q", got.user, session.user) + } + if got.expire != session.expire { + t.Errorf("got expire %q, want %q", got.expire, session.expire) + } + if got.token != session.token { + t.Errorf("got token %q, want %q", got.token, session.token) + } + + if err := s.delete(ctx, session.token); err != nil { + t.Fatalf("delete: %v", err) + } + + got, err = s.get(ctx, session.token) + if err == nil { + t.Errorf("got nil error after delete, want %v", errSessionNotFound) + } else if !errors.Is(err, errSessionNotFound) { + t.Fatalf("got error %q, want %q", err, errSessionNotFound) + } +} + +func TestSessionStore_deleteExpired(t *testing.T) { + ctx := context.Background() + db, err := newSqliteDB("") + if err != nil { + t.Fatalf("new sqlite db: %v", err) + } + + s := newSessionStore(db) + if err := s.create(ctx); err != nil { + t.Fatalf("create: %v", err) + } + + now := time.Now().Truncate(time.Second) + + session1 := &session{ + user: "testuser", + token: "testtoken1", + expire: now.Add(1 * time.Hour), + } + if err := s.insert(ctx, session1); err != nil { + t.Fatalf("insert: %v", err) + } + + session2 := &session{ + user: "testuser", + token: "testtoken2", + expire: now.Add(3 * time.Hour), + } + if err := s.insert(ctx, session2); err != nil { + t.Fatalf("insert: %v", err) + } + + if err := s.deleteExpired(ctx, now.Add(2*time.Hour)); err != nil { + t.Fatalf("delete expired: %v", err) + } + + got, err := s.get(ctx, session1.token) + if !errors.Is(err, errSessionNotFound) { + if err != nil { + t.Errorf("got error %q, want %q", err, errSessionNotFound) + } else { + t.Errorf("got session %+v, want nil", got) + } + } + + got, err = s.get(ctx, session2.token) + if err != nil { + t.Fatalf("get session: %v", err) + } + if got.user != session2.user { + t.Errorf("got user %q, want %q", got.user, session2.user) + } +}