diff --git a/Dockerfile b/Dockerfile index 0710d0f..f75f996 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.23-alpine AS builder +FROM golang:1.26-alpine AS builder ENV GOOS=linux ENV GOARCH=amd64 ENV CGO_ENABLED=0 diff --git a/Dockerfile.debug b/Dockerfile.debug new file mode 100644 index 0000000..dc2b585 --- /dev/null +++ b/Dockerfile.debug @@ -0,0 +1,24 @@ +FROM golang:1.26-alpine AS builder +ENV GOOS=linux +ENV GOARCH=amd64 +ENV CGO_ENABLED=0 + +WORKDIR /app + +COPY . . +COPY ./ssh-sync .ssh-sync + +RUN go mod download +RUN go mod verify + +RUN go test ./... -cover +RUN go build -o /godocker + +FROM scratch + +WORKDIR / + +COPY --from=builder /godocker /godocker + +ENV NO_DOTENV=1 +ENTRYPOINT ["/godocker"] \ No newline at end of file diff --git a/docker-compose.yaml b/docker-compose.yaml index 449a350..dfba8f7 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -18,7 +18,7 @@ services: image: ssh-sync-server-prerelease container_name: ssh-sync-server ssh-sync-db: - image: therealpaulgg/ssh-sync-db:latest + image: ssh-sync-db container_name: ssh-sync-db-debug environment: - POSTGRES_USER=sshsync diff --git a/go.mod b/go.mod index b156136..2b05ca4 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module github.com/therealpaulgg/ssh-sync-server -go 1.23.0 - -toolchain go1.24.7 +go 1.26.0 require ( github.com/go-chi/chi v1.5.4 @@ -15,9 +13,9 @@ require ( github.com/georgysavva/scany/v2 v2.0.0 github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect - github.com/gobwas/ws v1.1.0 + github.com/gobwas/ws v1.4.0 github.com/goccy/go-json v0.10.2 // indirect - github.com/google/uuid v1.3.0 + github.com/google/uuid v1.6.0 github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgx/v5 v5.5.4 @@ -34,7 +32,6 @@ require ( github.com/samber/do v1.5.1 github.com/samber/lo v1.37.0 github.com/sethvargo/go-diceware v0.3.0 - github.com/therealpaulgg/ssh-sync v1.2.2 golang.org/x/crypto v0.35.0 // indirect golang.org/x/exp v0.0.0-20230111222715-75897c7a292a // indirect golang.org/x/sync v0.11.0 // indirect @@ -47,29 +44,14 @@ require ( ) require ( - github.com/atotto/clipboard v0.1.4 // indirect - github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect - github.com/charmbracelet/bubbles v0.20.0 // indirect - github.com/charmbracelet/bubbletea v1.1.0 // indirect - github.com/charmbracelet/lipgloss v0.13.1 // indirect - github.com/charmbracelet/x/ansi v0.3.2 // indirect - github.com/charmbracelet/x/term v0.2.0 // indirect - github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect + filippo.io/mldsa v0.0.0-20260215214346-43d0283efc3e + github.com/therealpaulgg/ssh-sync-common v0.0.1 +) + +require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/kr/text v0.2.0 // indirect - github.com/lucasb-eyer/go-colorful v1.2.0 // indirect - github.com/mattn/go-localereader v0.0.1 // indirect - github.com/mattn/go-runewidth v0.0.16 // indirect - github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect - github.com/muesli/cancelreader v0.2.2 // indirect - github.com/muesli/termenv v0.15.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rivo/uniseg v0.4.7 // indirect - github.com/russross/blackfriday/v2 v2.1.0 // indirect - github.com/sahilm/fuzzy v0.1.1 // indirect github.com/segmentio/asm v1.2.0 // indirect - github.com/urfave/cli/v2 v2.23.7 // indirect - github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 4f12175..7099c7d 100644 --- a/go.sum +++ b/go.sum @@ -1,30 +1,14 @@ -github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= -github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= -github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= -github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= -github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE= -github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU= -github.com/charmbracelet/bubbletea v1.1.0 h1:FjAl9eAL3HBCHenhz/ZPjkKdScmaS5SK69JAK2YJK9c= -github.com/charmbracelet/bubbletea v1.1.0/go.mod h1:9Ogk0HrdbHolIKHdjfFpyXJmiCzGwy+FesYkZr7hYU4= -github.com/charmbracelet/lipgloss v0.13.1 h1:Oik/oqDTMVA01GetT4JdEC033dNzWoQHdWnHnQmXE2A= -github.com/charmbracelet/lipgloss v0.13.1/go.mod h1:zaYVJ2xKSKEnTEEbX6uAHabh2d975RJ+0yfkFpRBz5U= -github.com/charmbracelet/x/ansi v0.3.2 h1:wsEwgAN+C9U06l9dCVMX0/L3x7ptvY1qmjMwyfE6USY= -github.com/charmbracelet/x/ansi v0.3.2/go.mod h1:dk73KoMTT5AX5BsX0KrqhsTqAnhZZoCBjs7dGWp4Ktw= -github.com/charmbracelet/x/term v0.2.0 h1:cNB9Ot9q8I711MyZ7myUR5HFWL/lc3OpU8jZ4hwm0x0= -github.com/charmbracelet/x/term v0.2.0/go.mod h1:GVxgxAbjUrmpvIINHIQnJJKpMlHiZ4cktEQCN6GWyF0= +filippo.io/mldsa v0.0.0-20260215214346-43d0283efc3e h1:VsUbObBMxXlc23Eb9VeeJYE4jvTs87qa5RqSN2U5FJU= +filippo.io/mldsa v0.0.0-20260215214346-43d0283efc3e/go.mod h1:32qQ5yj3R24Eu03iWFWchdC3OB653wPvoepWejkefbY= github.com/cockroachdb/cockroach-go/v2 v2.2.0 h1:/5znzg5n373N/3ESjHF5SMLxiW4RKB05Ql//KWfeTFs= github.com/cockroachdb/cockroach-go/v2 v2.2.0/go.mod h1:u3MiKYGupPPjkn3ozknpMUpxPaNLTFWAya419/zv6eI= github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= -github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= -github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= -github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/georgysavva/scany/v2 v2.0.0 h1:RGXqxDv4row7/FYoK8MRXAZXqoWF/NM+NP0q50k3DKU= github.com/georgysavva/scany/v2 v2.0.0/go.mod h1:sigOdh+0qb/+aOs3TVhehVT10p8qJL7K/Zhyz8vWo38= github.com/go-chi/chi v1.5.4 h1:QHdzF2szwjqVV4wmByUnTcsbIg7UGaQ0tPF2t5GcAIs= @@ -33,8 +17,8 @@ github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= -github.com/gobwas/ws v1.1.0 h1:7RFti/xnNkMJnrK7D1yQ/iCIB5OrrY/54/H930kIbHA= -github.com/gobwas/ws v1.1.0/go.mod h1:nzvNcVha5eUziGrbxFCo6qFIojQHjJV5cLYIbezhfL0= +github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs= +github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -42,8 +26,8 @@ github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/golang/mock v1.7.0-rc.1 h1:YojYx61/OLFsiv6Rw1Z96LpldJIy31o+UHmwAUMJ6/U= github.com/golang/mock v1.7.0-rc.1/go.mod h1:s42URUywIqd+OcERslBJvOjepvNymP31m3q8d/GkuRs= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= @@ -72,40 +56,20 @@ github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNB github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/lib/pq v1.10.0 h1:Zx5DJFEYQXio93kgXnQ09fXNiUKsqv4OUEu2UtGcB1E= github.com/lib/pq v1.10.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= -github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= -github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= -github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= -github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= -github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= -github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= -github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= -github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= -github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= -github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY= github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0= -github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= -github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA= -github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= github.com/samber/do v1.5.1 h1:32/S8RgoKYa2wpf8TrakzyOFj0C/QQV4df09x1nza7I= github.com/samber/do v1.5.1/go.mod h1:DWqBvumy8dyb2vEnYZE7D7zaVEB64J45B0NjTlY/M4k= github.com/samber/lo v1.37.0 h1:XjVcB8g6tgUp8rsPsJ2CvhClfImrpL04YpQHXeHPhRw= @@ -123,14 +87,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/therealpaulgg/ssh-sync v0.3.0 h1:XFgcZ3JcccqmPFinWmweNPAYwX2yFiwbCQAJsjaFIq8= -github.com/therealpaulgg/ssh-sync v0.3.0/go.mod h1:vfadGVAZqMe5QLSgWuBwvnLsrJPY3Lr2yRAIMFHaCKk= -github.com/therealpaulgg/ssh-sync v1.2.2 h1:EzRtkHLF9vvG4HyfnUlRCWtpq2x3IMPhI18nm7jgHFo= -github.com/therealpaulgg/ssh-sync v1.2.2/go.mod h1:lc90qMx77ydUuUw/ezkJb0eRlzeRlCWLWLW/RORbQsI= -github.com/urfave/cli/v2 v2.23.7 h1:YHDQ46s3VghFHFf1DdF+Sh7H4RqhcM+t0TmZRJx4oJY= -github.com/urfave/cli/v2 v2.23.7/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc= -github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= -github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= +github.com/therealpaulgg/ssh-sync-common v0.0.1 h1:jGF8W/mS7YE0Le8jny+qfNYUcTayN1pfsp71QXFC9Ys= +github.com/therealpaulgg/ssh-sync-common v0.0.1/go.mod h1:eGg17M5ihJpAIJ7RDXot0UDK6K4wRTHciY1rPJolaDU= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -149,10 +107,8 @@ golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201207223542-d4d67f95c62d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/pkg/crypto/pqc.go b/pkg/crypto/pqc.go new file mode 100644 index 0000000..f44591e --- /dev/null +++ b/pkg/crypto/pqc.go @@ -0,0 +1,151 @@ +package crypto + +import ( + "encoding/base64" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "strings" + "time" + + "filippo.io/mldsa" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" +) + +type KeyType int + +const ( + KeyTypeUnknown KeyType = iota + KeyTypeECDSA + KeyTypeMLDSA +) + +func DetectKeyType(pemBytes []byte) KeyType { + block, _ := pem.Decode(pemBytes) + if block == nil { + return KeyTypeUnknown + } + switch block.Type { + case "PUBLIC KEY": + return KeyTypeECDSA + case "MLDSA PUBLIC KEY": + return KeyTypeMLDSA + default: + return KeyTypeUnknown + } +} + +func ParseMLDSAPublicKey(pemBytes []byte) (*mldsa.PublicKey, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, errors.New("failed to decode PEM block") + } + if block.Type != "MLDSA PUBLIC KEY" { + return nil, fmt.Errorf("unexpected PEM block type: %s", block.Type) + } + pk, err := mldsa.NewPublicKey(mldsa.MLDSA65(), block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse ML-DSA public key: %w", err) + } + return pk, nil +} + +func ValidatePublicKey(pemBytes []byte) (KeyType, error) { + kt := DetectKeyType(pemBytes) + switch kt { + case KeyTypeECDSA: + key, err := jwk.ParseKey(pemBytes, jwk.WithPEM(true)) + if err != nil { + return KeyTypeUnknown, fmt.Errorf("invalid ECDSA key: %w", err) + } + if key.KeyType() != jwa.EC { + return KeyTypeUnknown, errors.New("key is not EC type") + } + return KeyTypeECDSA, nil + case KeyTypeMLDSA: + if _, err := ParseMLDSAPublicKey(pemBytes); err != nil { + return KeyTypeUnknown, err + } + return KeyTypeMLDSA, nil + default: + return KeyTypeUnknown, errors.New("unsupported key type") + } +} + +type jwtHeader struct { + Alg string `json:"alg"` + Typ string `json:"typ"` +} + +func DetectJWTAlgorithm(tokenString string) (string, error) { + parts := strings.SplitN(tokenString, ".", 3) + if len(parts) != 3 { + return "", errors.New("invalid JWT format") + } + headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return "", fmt.Errorf("failed to decode JWT header: %w", err) + } + var header jwtHeader + if err := json.Unmarshal(headerBytes, &header); err != nil { + return "", fmt.Errorf("failed to parse JWT header: %w", err) + } + return header.Alg, nil +} + +type jwtClaims struct { + Username string `json:"username"` + Machine string `json:"machine"` + Exp float64 `json:"exp"` +} + +func ExtractJWTClaims(tokenString string) (username, machine string, err error) { + parts := strings.SplitN(tokenString, ".", 3) + if len(parts) != 3 { + return "", "", errors.New("invalid JWT format") + } + payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", "", fmt.Errorf("failed to decode JWT payload: %w", err) + } + var claims jwtClaims + if err := json.Unmarshal(payloadBytes, &claims); err != nil { + return "", "", fmt.Errorf("failed to parse JWT claims: %w", err) + } + return claims.Username, claims.Machine, nil +} + +func VerifyMLDSAJWT(tokenString string, pubKey *mldsa.PublicKey) error { + parts := strings.SplitN(tokenString, ".", 3) + if len(parts) != 3 { + return errors.New("invalid JWT format") + } + + signedContent := []byte(parts[0] + "." + parts[1]) + + sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return fmt.Errorf("failed to decode signature: %w", err) + } + + if err := mldsa.Verify(pubKey, signedContent, sigBytes, nil); err != nil { + return errors.New("ML-DSA signature verification failed") + } + + payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return fmt.Errorf("failed to decode payload: %w", err) + } + var claims jwtClaims + if err := json.Unmarshal(payloadBytes, &claims); err != nil { + return fmt.Errorf("failed to parse claims: %w", err) + } + + if int64(claims.Exp) <= time.Now().Unix() { + return errors.New("token expired") + } + + return nil +} diff --git a/pkg/crypto/pqc_test.go b/pkg/crypto/pqc_test.go new file mode 100644 index 0000000..6aed51f --- /dev/null +++ b/pkg/crypto/pqc_test.go @@ -0,0 +1,193 @@ +package crypto + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + "strings" + "testing" + "time" + + "filippo.io/mldsa" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func generateECDSAPEM(t *testing.T) []byte { + t.Helper() + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + pubBytes, err := x509.MarshalPKIXPublicKey(&priv.PublicKey) + require.NoError(t, err) + return pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubBytes}) +} + +func generateMLDSAPEM(t *testing.T) ([]byte, *mldsa.PublicKey, *mldsa.PrivateKey) { + t.Helper() + priv, err := mldsa.GenerateKey(mldsa.MLDSA65()) + require.NoError(t, err) + pub := priv.PublicKey() + pemBytes := pem.EncodeToMemory(&pem.Block{Type: "MLDSA PUBLIC KEY", Bytes: pub.Bytes()}) + return pemBytes, pub, priv +} + +func signMLDSAJWT(t *testing.T, priv *mldsa.PrivateKey, username, machine string, exp time.Time) string { + t.Helper() + header := `{"alg":"MLDSA","typ":"JWT"}` + claims := fmt.Sprintf( + `{"iss":"test","iat":%d,"exp":%d,"username":"%s","machine":"%s"}`, + time.Now().Add(-1*time.Minute).Unix(), exp.Unix(), username, machine, + ) + h := base64.RawURLEncoding.EncodeToString([]byte(header)) + c := base64.RawURLEncoding.EncodeToString([]byte(claims)) + signingInput := h + "." + c + sig, err := priv.Sign(nil, []byte(signingInput), nil) + require.NoError(t, err) + s := base64.RawURLEncoding.EncodeToString(sig) + return signingInput + "." + s +} + +func TestDetectKeyType_ECDSA(t *testing.T) { + pemBytes := generateECDSAPEM(t) + assert.Equal(t, KeyTypeECDSA, DetectKeyType(pemBytes)) +} + +func TestDetectKeyType_MLDSA(t *testing.T) { + pemBytes, _, _ := generateMLDSAPEM(t) + assert.Equal(t, KeyTypeMLDSA, DetectKeyType(pemBytes)) +} + +func TestDetectKeyType_Invalid(t *testing.T) { + assert.Equal(t, KeyTypeUnknown, DetectKeyType([]byte("not a pem"))) +} + +func TestDetectKeyType_UnknownBlockType(t *testing.T) { + pemBytes := pem.EncodeToMemory(&pem.Block{Type: "SOMETHING ELSE", Bytes: []byte{1, 2, 3}}) + assert.Equal(t, KeyTypeUnknown, DetectKeyType(pemBytes)) +} + +func TestParseMLDSAPublicKey_Valid(t *testing.T) { + pemBytes, _, _ := generateMLDSAPEM(t) + pk, err := ParseMLDSAPublicKey(pemBytes) + require.NoError(t, err) + assert.NotNil(t, pk) +} + +func TestParseMLDSAPublicKey_WrongPEMType(t *testing.T) { + pemBytes := generateECDSAPEM(t) + _, err := ParseMLDSAPublicKey(pemBytes) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unexpected PEM block type") +} + +func TestParseMLDSAPublicKey_InvalidData(t *testing.T) { + pemBytes := pem.EncodeToMemory(&pem.Block{Type: "MLDSA PUBLIC KEY", Bytes: []byte{1, 2, 3}}) + _, err := ParseMLDSAPublicKey(pemBytes) + assert.Error(t, err) +} + +func TestValidatePublicKey_ECDSA(t *testing.T) { + pemBytes := generateECDSAPEM(t) + kt, err := ValidatePublicKey(pemBytes) + require.NoError(t, err) + assert.Equal(t, KeyTypeECDSA, kt) +} + +func TestValidatePublicKey_MLDSA(t *testing.T) { + pemBytes, _, _ := generateMLDSAPEM(t) + kt, err := ValidatePublicKey(pemBytes) + require.NoError(t, err) + assert.Equal(t, KeyTypeMLDSA, kt) +} + +func TestValidatePublicKey_Invalid(t *testing.T) { + _, err := ValidatePublicKey([]byte("garbage")) + assert.Error(t, err) +} + +func TestDetectJWTAlgorithm_ES512(t *testing.T) { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"ES512","typ":"JWT"}`)) + payload := base64.RawURLEncoding.EncodeToString([]byte(`{}`)) + sig := base64.RawURLEncoding.EncodeToString([]byte("fakesig")) + token := header + "." + payload + "." + sig + + alg, err := DetectJWTAlgorithm(token) + require.NoError(t, err) + assert.Equal(t, "ES512", alg) +} + +func TestDetectJWTAlgorithm_MLDSA(t *testing.T) { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"MLDSA","typ":"JWT"}`)) + payload := base64.RawURLEncoding.EncodeToString([]byte(`{}`)) + sig := base64.RawURLEncoding.EncodeToString([]byte("fakesig")) + token := header + "." + payload + "." + sig + + alg, err := DetectJWTAlgorithm(token) + require.NoError(t, err) + assert.Equal(t, "MLDSA", alg) +} + +func TestDetectJWTAlgorithm_InvalidFormat(t *testing.T) { + _, err := DetectJWTAlgorithm("not.a.valid-base64!!!") + assert.Error(t, err) +} + +func TestExtractJWTClaims(t *testing.T) { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"MLDSA"}`)) + payload := base64.RawURLEncoding.EncodeToString([]byte(`{"username":"alice","machine":"laptop"}`)) + sig := base64.RawURLEncoding.EncodeToString([]byte("sig")) + token := header + "." + payload + "." + sig + + username, machine, err := ExtractJWTClaims(token) + require.NoError(t, err) + assert.Equal(t, "alice", username) + assert.Equal(t, "laptop", machine) +} + +func TestExtractJWTClaims_InvalidFormat(t *testing.T) { + _, _, err := ExtractJWTClaims("not-a-jwt") + assert.Error(t, err) +} + +func TestVerifyMLDSAJWT_Valid(t *testing.T) { + _, pub, priv := generateMLDSAPEM(t) + token := signMLDSAJWT(t, priv, "user1", "machine1", time.Now().Add(5*time.Minute)) + err := VerifyMLDSAJWT(token, pub) + assert.NoError(t, err) +} + +func TestVerifyMLDSAJWT_Expired(t *testing.T) { + _, pub, priv := generateMLDSAPEM(t) + token := signMLDSAJWT(t, priv, "user1", "machine1", time.Now().Add(-5*time.Minute)) + err := VerifyMLDSAJWT(token, pub) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expired") +} + +func TestVerifyMLDSAJWT_BadSignature(t *testing.T) { + _, _, priv := generateMLDSAPEM(t) + token := signMLDSAJWT(t, priv, "user1", "machine1", time.Now().Add(5*time.Minute)) + + priv2, _ := mldsa.GenerateKey(mldsa.MLDSA65()) + err := VerifyMLDSAJWT(token, priv2.PublicKey()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "verification failed") +} + +func TestVerifyMLDSAJWT_TamperedPayload(t *testing.T) { + _, pub, priv := generateMLDSAPEM(t) + token := signMLDSAJWT(t, priv, "user1", "machine1", time.Now().Add(5*time.Minute)) + + parts := strings.SplitN(token, ".", 3) + require.Len(t, parts, 3) + parts[1] = base64.RawURLEncoding.EncodeToString([]byte(`{"username":"evil","machine":"bad","exp":9999999999}`)) + tampered := parts[0] + "." + parts[1] + "." + parts[2] + + err := VerifyMLDSAJWT(tampered, pub) + assert.Error(t, err) +} + diff --git a/pkg/database/repository/machine.go b/pkg/database/repository/machine.go index 70d828c..eb01a84 100644 --- a/pkg/database/repository/machine.go +++ b/pkg/database/repository/machine.go @@ -20,6 +20,7 @@ type MachineRepository interface { CreateMachine(machine *models.Machine) (*models.Machine, error) CreateMachineTx(machine *models.Machine, tx pgx.Tx) (*models.Machine, error) GetUserMachines(id uuid.UUID) ([]models.Machine, error) + UpdateMachinePublicKey(id uuid.UUID, publicKey []byte) error } type MachineRepo struct { @@ -107,6 +108,16 @@ func (repo *MachineRepo) CreateMachineTx(machine *models.Machine, tx pgx.Tx) (*m return newMachine, nil } +func (repo *MachineRepo) UpdateMachinePublicKey(id uuid.UUID, publicKey []byte) error { + q := do.MustInvoke[database.DataAccessor](repo.Injector) + _, err := q.GetConnection().Exec( + context.TODO(), + "UPDATE machines SET public_key = $1 WHERE id = $2", + publicKey, id, + ) + return err +} + func (repo *MachineRepo) GetUserMachines(id uuid.UUID) ([]models.Machine, error) { q := do.MustInvoke[query.QueryService[models.Machine]](repo.Injector) machines, err := q.Query("select * from machines where user_id = $1", id) diff --git a/pkg/database/repository/machinemock.go b/pkg/database/repository/machinemock.go index 170bb52..b162d42 100644 --- a/pkg/database/repository/machinemock.go +++ b/pkg/database/repository/machinemock.go @@ -110,6 +110,20 @@ func (mr *MockMachineRepositoryMockRecorder) GetMachineByNameAndUser(machineName return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMachineByNameAndUser", reflect.TypeOf((*MockMachineRepository)(nil).GetMachineByNameAndUser), machineName, userID) } +// UpdateMachinePublicKey mocks base method. +func (m *MockMachineRepository) UpdateMachinePublicKey(id uuid.UUID, publicKey []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateMachinePublicKey", id, publicKey) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateMachinePublicKey indicates an expected call of UpdateMachinePublicKey. +func (mr *MockMachineRepositoryMockRecorder) UpdateMachinePublicKey(id, publicKey interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMachinePublicKey", reflect.TypeOf((*MockMachineRepository)(nil).UpdateMachinePublicKey), id, publicKey) +} + // GetUserMachines mocks base method. func (m *MockMachineRepository) GetUserMachines(id uuid.UUID) ([]models.Machine, error) { m.ctrl.T.Helper() diff --git a/pkg/web/live/main.go b/pkg/web/live/main.go index c64efa1..95ce8a5 100644 --- a/pkg/web/live/main.go +++ b/pkg/web/live/main.go @@ -14,11 +14,12 @@ import ( "github.com/rs/zerolog/log" "github.com/samber/do" "github.com/sethvargo/go-diceware/diceware" + "github.com/therealpaulgg/ssh-sync-common/pkg/dto" + "github.com/therealpaulgg/ssh-sync-common/pkg/wsutils" + pqc "github.com/therealpaulgg/ssh-sync-server/pkg/crypto" "github.com/therealpaulgg/ssh-sync-server/pkg/database/models" "github.com/therealpaulgg/ssh-sync-server/pkg/database/repository" "github.com/therealpaulgg/ssh-sync-server/pkg/web/middleware/context_keys" - "github.com/therealpaulgg/ssh-sync/pkg/dto" - "github.com/therealpaulgg/ssh-sync/pkg/utils" ) // Computer A creates a live connection. @@ -37,7 +38,7 @@ type ChallengeResponse struct { type ChallengeSession struct { Username string ChallengeAccepted chan bool - ChallengerChannel chan []byte + ChallengerChannel chan *dto.PublicKeyDto ResponderChannel chan []byte } @@ -84,7 +85,7 @@ func MachineChallengeResponseHandler(i *do.Injector, r *http.Request, w http.Res log.Warn().Msg("Could not get user from context") return } - foo, err := utils.ReadClientMessage[dto.ChallengeResponseDto](&conn) + foo, err := wsutils.ReadClientMessage[dto.ChallengeResponseDto](&conn) if err != nil { log.Err(err).Msg("Error reading client message") return @@ -92,7 +93,7 @@ func MachineChallengeResponseHandler(i *do.Injector, r *http.Request, w http.Res chalChan, ok := ChallengeResponseDict.ReadChallenge(foo.Data.Challenge) if !ok { log.Warn().Msg("Could not find challenge in dict") - if err := utils.WriteServerError[dto.ChallengeSuccessEncryptedKeyDto](&conn, "Invalid challenge response."); err != nil { + if err := wsutils.WriteServerError[dto.ChallengeSuccessEncryptedKeyDto](&conn, "Invalid challenge response."); err != nil { log.Err(err).Msg("Error writing server error") } return @@ -108,19 +109,20 @@ func MachineChallengeResponseHandler(i *do.Injector, r *http.Request, w http.Res key := <-chalChan.ChallengerChannel if key == nil { log.Debug().Msg("Response from challenger channel - key is nil. Exiting.") - if err := utils.WriteServerError[dto.ChallengeSuccessEncryptedKeyDto](&conn, "Error responding to challenge - client abruptly closed connection."); err != nil { + if err := wsutils.WriteServerError[dto.ChallengeSuccessEncryptedKeyDto](&conn, "Error responding to challenge - client abruptly closed connection."); err != nil { log.Err(err).Msg("Error writing server error") } return } keys := dto.ChallengeSuccessEncryptedKeyDto{ - PublicKey: key, + PublicKey: key.PublicKey, + EncapsulationKey: key.EncapsulationKey, } - if err := utils.WriteServerMessage(&conn, keys); err != nil { + if err := wsutils.WriteServerMessage(&conn, keys); err != nil { log.Err(err).Msg("Error writing server message") return } - encMasterKeyDto, err := utils.ReadClientMessage[dto.EncryptedMasterKeyDto](&conn) + encMasterKeyDto, err := wsutils.ReadClientMessage[dto.EncryptedMasterKeyDto](&conn) if err != nil { log.Err(err).Msg("Error reading client message") return @@ -139,9 +141,9 @@ func NewMachineChallenge(i *do.Injector, r *http.Request, w http.ResponseWriter) func NewMachineChallengeHandler(i *do.Injector, r *http.Request, w http.ResponseWriter, c *net.Conn) { conn := *c - defer conn.Close() + defer conn.Close() // first message sent should be JSON payload - userMachine, err := utils.ReadClientMessage[dto.UserMachineDto](&conn) + userMachine, err := wsutils.ReadClientMessage[dto.UserMachineDto](&conn) if err != nil { log.Err(err).Msg("Error reading client message") return @@ -149,7 +151,7 @@ func NewMachineChallengeHandler(i *do.Injector, r *http.Request, w http.Response userRepo := do.MustInvoke[repository.UserRepository](i) user, err := userRepo.GetUserByUsername(userMachine.Data.Username) if errors.Is(err, sql.ErrNoRows) || user == nil { - if err := utils.WriteServerError[dto.MessageDto](&conn, "User not found"); err != nil { + if err := wsutils.WriteServerError[dto.MessageDto](&conn, "User not found"); err != nil { log.Err(err).Msg("Error writing server error") } return @@ -162,7 +164,7 @@ func NewMachineChallengeHandler(i *do.Injector, r *http.Request, w http.Response machine, err := machineRepo.GetMachineByNameAndUser(userMachine.Data.MachineName, user.ID) // if the machine already exists, reject if err == nil && machine.ID != uuid.Nil { - if err = utils.WriteServerError[dto.MessageDto](&conn, "Machine already exists"); err != nil { + if err = wsutils.WriteServerError[dto.MessageDto](&conn, "Machine already exists"); err != nil { log.Err(err).Msg("Error writing server error") } return @@ -178,13 +180,13 @@ func NewMachineChallengeHandler(i *do.Injector, r *http.Request, w http.Response words, err := diceware.GenerateWithWordList(3, diceware.WordListEffLarge()) if err != nil { log.Err(err).Msg("Error generating diceware") - if err := utils.WriteServerError[dto.MessageDto](&conn, "Error generating diceware"); err != nil { + if err := wsutils.WriteServerError[dto.MessageDto](&conn, "Error generating diceware"); err != nil { log.Err(err).Msg("Error writing server error") } return } challengePhrase := strings.Join(words, "-") - if err := utils.WriteServerMessage(&conn, dto.MessageDto{Message: challengePhrase}); err != nil { + if err := wsutils.WriteServerMessage(&conn, dto.MessageDto{Message: challengePhrase}); err != nil { log.Err(err).Msg("Error writing challenge phrase") return } @@ -198,7 +200,7 @@ func NewMachineChallengeHandler(i *do.Injector, r *http.Request, w http.Response ChallengeResponseDict.WriteChallenge(challengePhrase, ChallengeSession{ Username: user.Username, ChallengeAccepted: make(chan bool), - ChallengerChannel: make(chan []byte), + ChallengerChannel: make(chan *dto.PublicKeyDto), ResponderChannel: make(chan []byte), }) defer func() { @@ -252,33 +254,40 @@ func NewMachineChallengeHandler(i *do.Injector, r *http.Request, w http.Response challengeResult := <-challengeResponse if !challengeResult { - if err := utils.WriteServerError[dto.MessageDto](&conn, "Challenge timed out"); err != nil { + if err := wsutils.WriteServerError[dto.MessageDto](&conn, "Challenge timed out"); err != nil { log.Err(err).Msg("Error writing server error") } return } - if err := utils.WriteServerMessage(&conn, dto.MessageDto{Message: "Challenge accepted!"}); err != nil { + if err := wsutils.WriteServerMessage(&conn, dto.MessageDto{Message: "Challenge accepted!"}); err != nil { log.Err(err).Msg("Error writing challenge accepted") return } - pubkey, err := utils.ReadClientMessage[dto.PublicKeyDto](&conn) + pubkey, err := wsutils.ReadClientMessage[dto.PublicKeyDto](&conn) if err != nil { log.Err(err).Msg("Error reading client message") return } - cha.ChallengerChannel <- pubkey.Data.PublicKey + if _, err := pqc.ValidatePublicKey(pubkey.Data.PublicKey); err != nil { + log.Err(err).Msg("Invalid public key format in challenge flow") + if err := wsutils.WriteServerError[dto.MessageDto](&conn, "Invalid public key format"); err != nil { + log.Err(err).Msg("Error writing server error") + } + return + } + cha.ChallengerChannel <- &pubkey.Data encryptedMasterKey := <-cha.ResponderChannel machine.PublicKey = pubkey.Data.PublicKey if _, err = machineRepo.CreateMachine(machine); err != nil { log.Err(err).Msg("Error creating machine") return } - if err := utils.WriteServerMessage(&conn, dto.EncryptedMasterKeyDto{EncryptedMasterKey: encryptedMasterKey}); err != nil { + if err := wsutils.WriteServerMessage(&conn, dto.EncryptedMasterKeyDto{EncryptedMasterKey: encryptedMasterKey}); err != nil { log.Err(err).Msg("Error writing encrypted master key") return } - if err := utils.WriteServerMessage(&conn, dto.MessageDto{Message: "Everything is done, you can now use ssh-sync"}); err != nil { + if err := wsutils.WriteServerMessage(&conn, dto.MessageDto{Message: "Everything is done, you can now use ssh-sync"}); err != nil { log.Err(err).Msg("Error writing final message") return } diff --git a/pkg/web/middleware/auth.go b/pkg/web/middleware/auth.go index 4130861..e385cd2 100644 --- a/pkg/web/middleware/auth.go +++ b/pkg/web/middleware/auth.go @@ -13,6 +13,7 @@ import ( "github.com/lestrrat-go/jwx/v2/jwt" "github.com/rs/zerolog/log" "github.com/samber/do" + pqc "github.com/therealpaulgg/ssh-sync-server/pkg/crypto" "github.com/therealpaulgg/ssh-sync-server/pkg/database/repository" "github.com/therealpaulgg/ssh-sync-server/pkg/web/middleware/context_keys" ) @@ -36,22 +37,47 @@ func ConfigureAuth(i *do.Injector) func(http.Handler) http.Handler { w.WriteHeader(http.StatusUnauthorized) return } - token, err := jwt.ParseString(tokenString, jwt.WithVerify(false)) + + alg, err := pqc.DetectJWTAlgorithm(tokenString) if err != nil { - log.Debug().Msg(fmt.Sprintf("Error parsing JWT: %s", err)) - w.WriteHeader(http.StatusUnauthorized) - return - } - username, ok := token.PrivateClaims()["username"].(string) - if username == "" || !ok { + log.Debug().Msg(fmt.Sprintf("Error detecting JWT algorithm: %s", err)) w.WriteHeader(http.StatusUnauthorized) return } - machine, ok := token.PrivateClaims()["machine"].(string) - if machine == "" || !ok { + + var username, machine string + switch alg { + case "ES256", "ES512": + token, err := jwt.ParseString(tokenString, jwt.WithVerify(false)) + if err != nil { + log.Debug().Msg(fmt.Sprintf("Error parsing JWT: %s", err)) + w.WriteHeader(http.StatusUnauthorized) + return + } + var ok bool + username, ok = token.PrivateClaims()["username"].(string) + if username == "" || !ok { + w.WriteHeader(http.StatusUnauthorized) + return + } + machine, ok = token.PrivateClaims()["machine"].(string) + if machine == "" || !ok { + w.WriteHeader(http.StatusUnauthorized) + return + } + case "MLDSA": + username, machine, err = pqc.ExtractJWTClaims(tokenString) + if err != nil || username == "" || machine == "" { + log.Debug().Msg(fmt.Sprintf("Error extracting PQ JWT claims: %v", err)) + w.WriteHeader(http.StatusUnauthorized) + return + } + default: + log.Debug().Msg(fmt.Sprintf("Unsupported JWT algorithm: %s", alg)) w.WriteHeader(http.StatusUnauthorized) return } + userRepo := do.MustInvoke[repository.UserRepository](i) user, err := userRepo.GetUserByUsername(username) if err != nil { @@ -66,22 +92,37 @@ func ConfigureAuth(i *do.Injector) func(http.Handler) http.Handler { w.WriteHeader(http.StatusUnauthorized) return } - key, err := jwk.ParseKey(m.PublicKey, jwk.WithPEM(true)) - if err != nil { - log.Error().Msg(err.Error()) - w.WriteHeader(http.StatusInternalServerError) - return - } - if _, err := jwt.ParseRequest(r, jwt.WithKey(jwa.ES512, key)); err != nil { - log.Debug().Msg(fmt.Sprintf("Error parsing JWT: %s", err)) - w.WriteHeader(http.StatusUnauthorized) - return + + switch alg { + case "ES256", "ES512": + key, err := jwk.ParseKey(m.PublicKey, jwk.WithPEM(true)) + if err != nil { + log.Error().Msg(err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + if _, err := jwt.ParseRequest(r, jwt.WithKey(jwa.SignatureAlgorithm(alg), key)); err != nil { + log.Debug().Msg(fmt.Sprintf("Error verifying JWT: %s", err)) + w.WriteHeader(http.StatusUnauthorized) + return + } + case "MLDSA": + pubKey, err := pqc.ParseMLDSAPublicKey(m.PublicKey) + if err != nil { + log.Error().Msg(fmt.Sprintf("Error parsing ML-DSA key: %s", err)) + w.WriteHeader(http.StatusInternalServerError) + return + } + if err := pqc.VerifyMLDSAJWT(tokenString, pubKey); err != nil { + log.Debug().Msg(fmt.Sprintf("ML-DSA JWT verification failed: %s", err)) + w.WriteHeader(http.StatusUnauthorized) + return + } } + ctx := context.WithValue(r.Context(), context_keys.UserContextKey, user) ctx = context.WithValue(ctx, context_keys.MachineContextKey, m) next.ServeHTTP(w, r.WithContext(ctx)) }) } } - -// Auth middleware: parse a JWT signed with ES512 and verify it with the public key diff --git a/pkg/web/middleware/auth_test.go b/pkg/web/middleware/auth_test.go index 65e3010..ccb3f32 100644 --- a/pkg/web/middleware/auth_test.go +++ b/pkg/web/middleware/auth_test.go @@ -208,24 +208,13 @@ func TestConfigureAuthUnsignedToken(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - priv, pub, err := testutils.GenerateTestKeys() - if err != nil { - t.Fatal(err) - } - pubBytes, _, err := testutils.EncodeToPem(priv, pub) - if err != nil { - t.Fatal(err) - } - // Mock user and machine data - user := &models.User{ID: uuid.New(), Username: "testuser"} - machine := &models.Machine{ID: uuid.New(), Name: "testmachine", UserID: user.ID, PublicKey: pubBytes} - // Create HMAC signed token + // Create unsigned token (algorithm "none") builder := jwt.NewBuilder() builder.Issuer("github.com/therealpaulgg/ssh-sync") builder.IssuedAt(time.Now()) builder.Expiration(time.Now().Add(time.Minute)) - builder.Claim("username", user.Username) - builder.Claim("machine", machine.Name) + builder.Claim("username", "testuser") + builder.Claim("machine", "testmachine") tok, err := builder.Build() if err != nil { t.Fatal(err) @@ -236,18 +225,6 @@ func TestConfigureAuthUnsignedToken(t *testing.T) { t.Fatal(err) } - mockUserRepo := repository.NewMockUserRepository(ctrl) - mockUserRepo.EXPECT().GetUserByUsername(user.Username).Return(user, nil).Times(1) - do.Provide(i, func(i *do.Injector) (repository.UserRepository, error) { - return mockUserRepo, nil - }) - - mockMachineRepo := repository.NewMockMachineRepository(ctrl) - mockMachineRepo.EXPECT().GetMachineByNameAndUser(machine.Name, user.ID).Return(machine, nil).Times(1) - do.Provide(i, func(i *do.Injector) (repository.MachineRepository, error) { - return mockMachineRepo, nil - }) - // Mock http request with Authorization header req, err := http.NewRequest("GET", "/", nil) if err != nil { @@ -255,11 +232,10 @@ func TestConfigureAuthUnsignedToken(t *testing.T) { } req.Header.Set("Authorization", "Bearer "+string(token)) - // Act + // Act - "none" algorithm is now rejected early before DB lookups rr := httptest.NewRecorder() f := ConfigureAuth(i) f(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Dummy handler w.WriteHeader(http.StatusOK) })).ServeHTTP(rr, req) @@ -357,3 +333,163 @@ func TestConfigureAuthFakeToken(t *testing.T) { // Assert assert.Equal(t, http.StatusUnauthorized, rr.Code) } + +func TestConfigureAuth_MLDSA(t *testing.T) { + // Arrange + i := do.New() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + pub, priv, err := testutils.GenerateMLDSATestKeys() + if err != nil { + t.Fatal(err) + } + pubPEM, err := testutils.EncodeMLDSAToPem(pub) + if err != nil { + t.Fatal(err) + } + + user := &models.User{ID: uuid.New(), Username: "testuser"} + machine := &models.Machine{ID: uuid.New(), Name: "testmachine", UserID: user.ID, PublicKey: pubPEM} + token, err := testutils.GenerateMLDSATestToken(user.Username, machine.Name, priv) + if err != nil { + t.Fatal(err) + } + + mockUserRepo := repository.NewMockUserRepository(ctrl) + mockUserRepo.EXPECT().GetUserByUsername(user.Username).Return(user, nil).Times(1) + do.Provide(i, func(i *do.Injector) (repository.UserRepository, error) { + return mockUserRepo, nil + }) + + mockMachineRepo := repository.NewMockMachineRepository(ctrl) + mockMachineRepo.EXPECT().GetMachineByNameAndUser(machine.Name, user.ID).Return(machine, nil).Times(1) + do.Provide(i, func(i *do.Injector) (repository.MachineRepository, error) { + return mockMachineRepo, nil + }) + + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Authorization", "Bearer "+token) + + // Act + rr := httptest.NewRecorder() + f := ConfigureAuth(i) + f(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(rr, req) + + // Assert + assert.Equal(t, http.StatusOK, rr.Code) +} + +func TestConfigureAuth_MLDSA_WrongKey(t *testing.T) { + // Arrange + i := do.New() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Generate two different keypairs + pub1, _, err := testutils.GenerateMLDSATestKeys() + if err != nil { + t.Fatal(err) + } + _, priv2, err := testutils.GenerateMLDSATestKeys() + if err != nil { + t.Fatal(err) + } + + // Store pub1 but sign with priv2 + pubPEM, err := testutils.EncodeMLDSAToPem(pub1) + if err != nil { + t.Fatal(err) + } + + user := &models.User{ID: uuid.New(), Username: "testuser"} + machine := &models.Machine{ID: uuid.New(), Name: "testmachine", UserID: user.ID, PublicKey: pubPEM} + token, err := testutils.GenerateMLDSATestToken(user.Username, machine.Name, priv2) + if err != nil { + t.Fatal(err) + } + + mockUserRepo := repository.NewMockUserRepository(ctrl) + mockUserRepo.EXPECT().GetUserByUsername(user.Username).Return(user, nil).Times(1) + do.Provide(i, func(i *do.Injector) (repository.UserRepository, error) { + return mockUserRepo, nil + }) + + mockMachineRepo := repository.NewMockMachineRepository(ctrl) + mockMachineRepo.EXPECT().GetMachineByNameAndUser(machine.Name, user.ID).Return(machine, nil).Times(1) + do.Provide(i, func(i *do.Injector) (repository.MachineRepository, error) { + return mockMachineRepo, nil + }) + + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Authorization", "Bearer "+token) + + // Act + rr := httptest.NewRecorder() + f := ConfigureAuth(i) + f(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(rr, req) + + // Assert + assert.Equal(t, http.StatusUnauthorized, rr.Code) +} + +func TestConfigureAuth_MLDSA_Expired(t *testing.T) { + // Arrange + i := do.New() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + pub, priv, err := testutils.GenerateMLDSATestKeys() + if err != nil { + t.Fatal(err) + } + pubPEM, err := testutils.EncodeMLDSAToPem(pub) + if err != nil { + t.Fatal(err) + } + + user := &models.User{ID: uuid.New(), Username: "testuser"} + machine := &models.Machine{ID: uuid.New(), Name: "testmachine", UserID: user.ID, PublicKey: pubPEM} + token, err := testutils.GenerateExpiredMLDSATestToken(user.Username, machine.Name, priv) + if err != nil { + t.Fatal(err) + } + + mockUserRepo := repository.NewMockUserRepository(ctrl) + mockUserRepo.EXPECT().GetUserByUsername(user.Username).Return(user, nil).Times(1) + do.Provide(i, func(i *do.Injector) (repository.UserRepository, error) { + return mockUserRepo, nil + }) + + mockMachineRepo := repository.NewMockMachineRepository(ctrl) + mockMachineRepo.EXPECT().GetMachineByNameAndUser(machine.Name, user.ID).Return(machine, nil).Times(1) + do.Provide(i, func(i *do.Injector) (repository.MachineRepository, error) { + return mockMachineRepo, nil + }) + + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Authorization", "Bearer "+token) + + // Act + rr := httptest.NewRecorder() + f := ConfigureAuth(i) + f(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(rr, req) + + // Assert + assert.Equal(t, http.StatusUnauthorized, rr.Code) +} diff --git a/pkg/web/router/routes/data.go b/pkg/web/router/routes/data.go index 6b1d1c3..1317ba1 100644 --- a/pkg/web/router/routes/data.go +++ b/pkg/web/router/routes/data.go @@ -13,12 +13,12 @@ import ( "github.com/rs/zerolog/log" "github.com/samber/do" "github.com/samber/lo" + "github.com/therealpaulgg/ssh-sync-common/pkg/dto" "github.com/therealpaulgg/ssh-sync-server/pkg/database/models" "github.com/therealpaulgg/ssh-sync-server/pkg/database/query" "github.com/therealpaulgg/ssh-sync-server/pkg/database/repository" "github.com/therealpaulgg/ssh-sync-server/pkg/web/middleware" "github.com/therealpaulgg/ssh-sync-server/pkg/web/middleware/context_keys" - "github.com/therealpaulgg/ssh-sync/pkg/dto" ) func getData(i *do.Injector) http.HandlerFunc { diff --git a/pkg/web/router/routes/data_test.go b/pkg/web/router/routes/data_test.go index d574940..97f6c77 100644 --- a/pkg/web/router/routes/data_test.go +++ b/pkg/web/router/routes/data_test.go @@ -16,12 +16,12 @@ import ( "github.com/google/uuid" "github.com/samber/do" "github.com/stretchr/testify/assert" + "github.com/therealpaulgg/ssh-sync-common/pkg/dto" "github.com/therealpaulgg/ssh-sync-server/pkg/database/models" "github.com/therealpaulgg/ssh-sync-server/pkg/database/query" "github.com/therealpaulgg/ssh-sync-server/pkg/database/repository" "github.com/therealpaulgg/ssh-sync-server/pkg/web/testutils" "github.com/therealpaulgg/ssh-sync-server/test/pgx" - "github.com/therealpaulgg/ssh-sync/pkg/dto" ) func TestGetData(t *testing.T) { diff --git a/pkg/web/router/routes/machine.go b/pkg/web/router/routes/machine.go index 3b53253..10a0349 100644 --- a/pkg/web/router/routes/machine.go +++ b/pkg/web/router/routes/machine.go @@ -4,6 +4,7 @@ import ( "database/sql" "encoding/json" "errors" + "io" "net/http" "github.com/go-chi/chi" @@ -11,11 +12,12 @@ import ( "github.com/rs/zerolog/log" "github.com/samber/do" "github.com/samber/lo" + "github.com/therealpaulgg/ssh-sync-common/pkg/dto" + pqc "github.com/therealpaulgg/ssh-sync-server/pkg/crypto" "github.com/therealpaulgg/ssh-sync-server/pkg/database/models" "github.com/therealpaulgg/ssh-sync-server/pkg/database/repository" "github.com/therealpaulgg/ssh-sync-server/pkg/web/middleware" "github.com/therealpaulgg/ssh-sync-server/pkg/web/middleware/context_keys" - "github.com/therealpaulgg/ssh-sync/pkg/dto" ) type DeleteRequest struct { @@ -109,11 +111,50 @@ func deleteMachine(i *do.Injector) http.HandlerFunc { } } +func updateMachineKey(i *do.Injector) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + machine, ok := r.Context().Value(context_keys.MachineContextKey).(*models.Machine) + if !ok { + w.WriteHeader(http.StatusInternalServerError) + return + } + err := r.ParseMultipartForm(32 << 20) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + file, _, err := r.FormFile("key") + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + defer file.Close() + fileBytes, err := io.ReadAll(file) + if err != nil { + log.Err(err).Msg("error reading key file") + w.WriteHeader(http.StatusInternalServerError) + return + } + if _, err := pqc.ValidatePublicKey(fileBytes); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + machineRepo := do.MustInvoke[repository.MachineRepository](i) + if err := machineRepo.UpdateMachinePublicKey(machine.ID, fileBytes); err != nil { + log.Err(err).Msg("error updating machine public key") + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + } +} + func MachineRoutes(i *do.Injector) chi.Router { r := chi.NewRouter() r.Use(middleware.ConfigureAuth(i)) r.Get("/{machineId}", getMachineById(i)) r.Get("/", getMachines(i)) r.Delete("/", deleteMachine(i)) + r.Put("/key", updateMachineKey(i)) return r } diff --git a/pkg/web/router/routes/machine_test.go b/pkg/web/router/routes/machine_test.go index 107bedf..3d37486 100644 --- a/pkg/web/router/routes/machine_test.go +++ b/pkg/web/router/routes/machine_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "mime/multipart" "net/http" "net/http/httptest" "testing" @@ -13,10 +14,10 @@ import ( "github.com/google/uuid" "github.com/samber/do" "github.com/stretchr/testify/assert" + "github.com/therealpaulgg/ssh-sync-common/pkg/dto" "github.com/therealpaulgg/ssh-sync-server/pkg/database/models" "github.com/therealpaulgg/ssh-sync-server/pkg/database/repository" "github.com/therealpaulgg/ssh-sync-server/pkg/web/testutils" - "github.com/therealpaulgg/ssh-sync/pkg/dto" ) func TestGetMachine(t *testing.T) { @@ -148,4 +149,110 @@ func TestDeleteMachine(t *testing.T) { } } +func TestUpdateMachineKey(t *testing.T) { + // Arrange + pub, _, err := testutils.GenerateMLDSATestKeys() + if err != nil { + t.Fatal(err) + } + pubPEM, err := testutils.EncodeMLDSAToPem(pub) + if err != nil { + t.Fatal(err) + } + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + part, err := writer.CreateFormFile("key", "key") + if err != nil { + t.Fatal(err) + } + _, err = part.Write(pubPEM) + if err != nil { + t.Fatal(err) + } + err = writer.Close() + if err != nil { + t.Fatal(err) + } + + req, err := http.NewRequest("PUT", "/key", body) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + + user := testutils.GenerateUser() + machine := &models.Machine{ + ID: uuid.New(), + UserID: user.ID, + Name: "test", + PublicKey: []byte("old-key"), + } + req = testutils.AddUserContext(req, user) + req = testutils.AddMachineContext(req, machine) + + injector := do.New() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockMachineRepo := repository.NewMockMachineRepository(ctrl) + mockMachineRepo.EXPECT().UpdateMachinePublicKey(machine.ID, pubPEM).Return(nil) + do.Provide(injector, func(i *do.Injector) (repository.MachineRepository, error) { + return mockMachineRepo, nil + }) + + // Act + rr := httptest.NewRecorder() + router := chi.NewRouter() + router.Put("/key", updateMachineKey(injector)) + router.ServeHTTP(rr, req) + + // Assert + assert.Equal(t, http.StatusOK, rr.Code) +} + +func TestUpdateMachineKey_InvalidKey(t *testing.T) { + // Arrange + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + part, err := writer.CreateFormFile("key", "key") + if err != nil { + t.Fatal(err) + } + _, err = part.Write([]byte("not a valid key")) + if err != nil { + t.Fatal(err) + } + err = writer.Close() + if err != nil { + t.Fatal(err) + } + + req, err := http.NewRequest("PUT", "/key", body) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + + user := testutils.GenerateUser() + machine := &models.Machine{ + ID: uuid.New(), + UserID: user.ID, + Name: "test", + PublicKey: []byte("old-key"), + } + req = testutils.AddUserContext(req, user) + req = testutils.AddMachineContext(req, machine) + + injector := do.New() + + // Act + rr := httptest.NewRecorder() + router := chi.NewRouter() + router.Put("/key", updateMachineKey(injector)) + router.ServeHTTP(rr, req) + + // Assert + assert.Equal(t, http.StatusBadRequest, rr.Code) +} + // TODO non-happy-paths diff --git a/pkg/web/router/routes/setup.go b/pkg/web/router/routes/setup.go index ec56315..3bc7274 100644 --- a/pkg/web/router/routes/setup.go +++ b/pkg/web/router/routes/setup.go @@ -2,21 +2,20 @@ package routes import ( "errors" - "io/ioutil" + "io" "net/http" "github.com/go-chi/chi" "github.com/jackc/pgx/v5" - "github.com/lestrrat-go/jwx/v2/jwa" - "github.com/lestrrat-go/jwx/v2/jwk" "github.com/rs/zerolog/log" "github.com/samber/do" + "github.com/therealpaulgg/ssh-sync-common/pkg/dto" + pqc "github.com/therealpaulgg/ssh-sync-server/pkg/crypto" "github.com/therealpaulgg/ssh-sync-server/pkg/database/models" "github.com/therealpaulgg/ssh-sync-server/pkg/database/query" "github.com/therealpaulgg/ssh-sync-server/pkg/database/repository" "github.com/therealpaulgg/ssh-sync-server/pkg/web/live" "github.com/therealpaulgg/ssh-sync-server/pkg/web/middleware" - "github.com/therealpaulgg/ssh-sync/pkg/dto" ) func initialSetup(i *do.Injector) http.HandlerFunc { @@ -43,20 +42,14 @@ func initialSetup(i *do.Injector) http.HandlerFunc { w.WriteHeader(http.StatusBadRequest) return } - fileBytes, err := ioutil.ReadAll(file) + defer file.Close() + fileBytes, err := io.ReadAll(file) if err != nil { - log.Err(err).Msg("error reading file") w.WriteHeader(http.StatusInternalServerError) return } - // validate that it is in fact a public key - key, err := jwk.ParseKey(fileBytes, jwk.WithPEM(true)) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - keyType := key.KeyType() - if keyType != jwa.EC { + if _, err := pqc.ValidatePublicKey(fileBytes); err != nil { + log.Debug().Err(err).Msg("invalid public key") w.WriteHeader(http.StatusBadRequest) return } diff --git a/pkg/web/router/routes/setup_test.go b/pkg/web/router/routes/setup_test.go index 77b512f..3a56c22 100644 --- a/pkg/web/router/routes/setup_test.go +++ b/pkg/web/router/routes/setup_test.go @@ -78,4 +78,67 @@ func TestInitialSetup(t *testing.T) { } } +func TestInitialSetup_MLDSA(t *testing.T) { + // Arrange + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + _ = writer.WriteField("username", "test") + _ = writer.WriteField("machine_name", "mymachine") + pub, _, err := testutils.GenerateMLDSATestKeys() + if err != nil { + t.Fatal(err) + } + pubPEM, err := testutils.EncodeMLDSAToPem(pub) + if err != nil { + t.Fatal(err) + } + part, err := writer.CreateFormFile("key", "key") + if err != nil { + t.Fatal(err) + } + _, err = part.Write(pubPEM) + if err != nil { + t.Fatal(err) + } + err = writer.Close() + if err != nil { + t.Fatal(err) + } + req, err := http.NewRequest("POST", "/", body) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + injector := do.New() + ctrl := gomock.NewController(t) + mockTx := pgx.NewMockTx(ctrl) + mockTransactionService := query.NewMockTransactionService(ctrl) + mockTransactionService.EXPECT().StartTx(gomock.Any()).Return(mockTx, nil) + mockTransactionService.EXPECT().Commit(mockTx).Return(nil) + do.Provide(injector, func(i *do.Injector) (query.TransactionService, error) { + return mockTransactionService, nil + }) + mockUserRepository := repository.NewMockUserRepository(ctrl) + user := testutils.GenerateUser() + machine := testutils.GenerateMachine() + mockUserRepository.EXPECT().CreateUserTx(gomock.Any(), mockTx).Return(user, nil) + do.Provide(injector, func(i *do.Injector) (repository.UserRepository, error) { + return mockUserRepository, nil + }) + mockMachineRepository := repository.NewMockMachineRepository(ctrl) + mockMachineRepository.EXPECT().CreateMachineTx(gomock.Any(), mockTx).Return(machine, nil) + do.Provide(injector, func(i *do.Injector) (repository.MachineRepository, error) { + return mockMachineRepository, nil + }) + // Act + rr := httptest.NewRecorder() + handler := initialSetup(injector) + handler.ServeHTTP(rr, req) + // Assert + if status := rr.Code; status != http.StatusOK { + t.Errorf("initialSetup returned wrong status code: got %v want %v", + status, http.StatusOK) + } +} + // TODO non-happy-paths diff --git a/pkg/web/router/routes/user.go b/pkg/web/router/routes/user.go index 0a57dcd..ab16e6a 100644 --- a/pkg/web/router/routes/user.go +++ b/pkg/web/router/routes/user.go @@ -8,8 +8,8 @@ import ( "github.com/go-chi/chi" "github.com/samber/do" + "github.com/therealpaulgg/ssh-sync-common/pkg/dto" "github.com/therealpaulgg/ssh-sync-server/pkg/database/repository" - "github.com/therealpaulgg/ssh-sync/pkg/dto" ) func getUser(i *do.Injector) http.HandlerFunc { diff --git a/pkg/web/router/routes/user_test.go b/pkg/web/router/routes/user_test.go index e009c36..42ba778 100644 --- a/pkg/web/router/routes/user_test.go +++ b/pkg/web/router/routes/user_test.go @@ -12,9 +12,9 @@ import ( "github.com/golang/mock/gomock" "github.com/samber/do" "github.com/stretchr/testify/assert" + "github.com/therealpaulgg/ssh-sync-common/pkg/dto" "github.com/therealpaulgg/ssh-sync-server/pkg/database/models" "github.com/therealpaulgg/ssh-sync-server/pkg/database/repository" - "github.com/therealpaulgg/ssh-sync/pkg/dto" ) func TestGetUser(t *testing.T) { diff --git a/pkg/web/testutils/main.go b/pkg/web/testutils/main.go index 54f7c31..17d946f 100644 --- a/pkg/web/testutils/main.go +++ b/pkg/web/testutils/main.go @@ -6,9 +6,14 @@ import ( "crypto/elliptic" "crypto/rand" "crypto/x509" + "encoding/base64" + "encoding/json" "encoding/pem" + "fmt" "net/http" + "time" + "filippo.io/mldsa" "github.com/google/uuid" "github.com/therealpaulgg/ssh-sync-server/pkg/database/models" "github.com/therealpaulgg/ssh-sync-server/pkg/web/middleware/context_keys" @@ -70,3 +75,74 @@ func EncodeToPem(privKey *ecdsa.PrivateKey, pubKey *ecdsa.PublicKey) ([]byte, [] return pubBytes, privBytes, nil } + +// GenerateMLDSATestKeys generates an ML-DSA keypair. +func GenerateMLDSATestKeys() (*mldsa.PublicKey, *mldsa.PrivateKey, error) { + priv, err := mldsa.GenerateKey(mldsa.MLDSA65()) + if err != nil { + return nil, nil, err + } + return priv.PublicKey(), priv, nil +} + +// EncodeMLDSAToPem PEM-encodes an ML-DSA public key. +func EncodeMLDSAToPem(pub *mldsa.PublicKey) ([]byte, error) { + return pem.EncodeToMemory(&pem.Block{ + Type: "MLDSA PUBLIC KEY", + Bytes: pub.Bytes(), + }), nil +} + +// GenerateMLDSATestToken creates and signs a JWT with ML-DSA for testing. +func GenerateMLDSATestToken(username, machine string, priv *mldsa.PrivateKey) (string, error) { + header := `{"alg":"MLDSA","typ":"JWT"}` + now := time.Now() + claims, err := json.Marshal(map[string]interface{}{ + "iss": "github.com/therealpaulgg/ssh-sync", + "iat": now.Add(-1 * time.Minute).Unix(), + "exp": now.Add(2 * time.Minute).Unix(), + "username": username, + "machine": machine, + }) + if err != nil { + return "", fmt.Errorf("failed to marshal JWT claims: %w", err) + } + + h := base64.RawURLEncoding.EncodeToString([]byte(header)) + c := base64.RawURLEncoding.EncodeToString(claims) + signingInput := h + "." + c + + sig, err := priv.Sign(nil, []byte(signingInput), nil) + if err != nil { + return "", fmt.Errorf("failed to sign JWT: %w", err) + } + s := base64.RawURLEncoding.EncodeToString(sig) + return signingInput + "." + s, nil +} + +// GenerateExpiredMLDSATestToken creates an expired ML-DSA JWT for testing. +func GenerateExpiredMLDSATestToken(username, machine string, priv *mldsa.PrivateKey) (string, error) { + header := `{"alg":"MLDSA","typ":"JWT"}` + past := time.Now().Add(-10 * time.Minute) + claims, err := json.Marshal(map[string]interface{}{ + "iss": "github.com/therealpaulgg/ssh-sync", + "iat": past.Unix(), + "exp": past.Add(5 * time.Minute).Unix(), + "username": username, + "machine": machine, + }) + if err != nil { + return "", err + } + + h := base64.RawURLEncoding.EncodeToString([]byte(header)) + c := base64.RawURLEncoding.EncodeToString(claims) + signingInput := h + "." + c + + sig, err := priv.Sign(nil, []byte(signingInput), nil) + if err != nil { + return "", err + } + s := base64.RawURLEncoding.EncodeToString(sig) + return signingInput + "." + s, nil +}