diff --git a/age.go b/age.go index a4c2ad3d..d28aac84 100644 --- a/age.go +++ b/age.go @@ -113,6 +113,8 @@ type Stanza struct { const fileKeySize = 16 const streamNonceSize = 16 +var errNoRecipients = errors.New("no recipients specified") + // Encrypt encrypts a file to one or more recipients. // // Writes to the returned WriteCloser are encrypted and written to dst as an age @@ -122,7 +124,7 @@ const streamNonceSize = 16 // be encrypted and flushed to dst. func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) { if len(recipients) == 0 { - return nil, errors.New("no recipients specified") + return nil, errNoRecipients } fileKey := make([]byte, fileKeySize) @@ -130,41 +132,73 @@ func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) { return nil, err } + if err := writeHeader(dst, recipients, fileKey); err != nil { + return nil, fmt.Errorf("failed to write header: %v", err) + } + + nonce := make([]byte, streamNonceSize) + if _, err := rand.Read(nonce); err != nil { + return nil, err + } + if _, err := dst.Write(nonce); err != nil { + return nil, fmt.Errorf("failed to write nonce: %v", err) + } + + return stream.NewWriter(streamKey(fileKey, nonce), dst) +} + +// EncryptReader encrypts a reader to one or more recipients. +// +// It encrypts src into dst as an age file. Every recipient will be able to decrypt the file. +func EncryptReader(dst io.Writer, src io.Reader, recipients ...Recipient) error { + if len(recipients) == 0 { + return errNoRecipients + } + + fileKey := make([]byte, fileKeySize) + if _, err := rand.Read(fileKey); err != nil { + return err + } + + if err := writeHeader(dst, recipients, fileKey); err != nil { + return fmt.Errorf("failed to write header: %v", err) + } + + nonce := make([]byte, streamNonceSize) + if _, err := rand.Read(nonce); err != nil { + return err + } + if _, err := dst.Write(nonce); err != nil { + return fmt.Errorf("failed to write nonce: %v", err) + } + + return stream.Encrypt(dst, src, streamKey(fileKey, nonce)) +} + +func writeHeader(dst io.Writer, recipients []Recipient, fileKey []byte) error { hdr := &format.Header{} var labels []string for i, r := range recipients { stanzas, l, err := wrapWithLabels(r, fileKey) if err != nil { - return nil, fmt.Errorf("failed to wrap key for recipient #%d: %v", i, err) + return fmt.Errorf("failed to wrap key for recipient #%d: %v", i, err) } sort.Strings(l) if i == 0 { labels = l } else if !slicesEqual(labels, l) { - return nil, fmt.Errorf("incompatible recipients") + return fmt.Errorf("incompatible recipients") } for _, s := range stanzas { hdr.Recipients = append(hdr.Recipients, (*format.Stanza)(s)) } } if mac, err := headerMAC(fileKey, hdr); err != nil { - return nil, fmt.Errorf("failed to compute header MAC: %v", err) + return fmt.Errorf("failed to compute header MAC: %v", err) } else { hdr.MAC = mac } - if err := hdr.Marshal(dst); err != nil { - return nil, fmt.Errorf("failed to write header: %v", err) - } - - nonce := make([]byte, streamNonceSize) - if _, err := rand.Read(nonce); err != nil { - return nil, err - } - if _, err := dst.Write(nonce); err != nil { - return nil, fmt.Errorf("failed to write nonce: %v", err) - } - - return stream.NewWriter(streamKey(fileKey, nonce), dst) + return hdr.Marshal(dst) } func wrapWithLabels(r Recipient, fileKey []byte) (s []*Stanza, labels []string, err error) { diff --git a/age_test.go b/age_test.go index 8cf68670..03786590 100644 --- a/age_test.go +++ b/age_test.go @@ -41,6 +41,26 @@ func ExampleEncrypt() { // Encrypted file size: 219 } +func ExampleEncryptReader() { + publicKey := "age1cy0su9fwf3gf9mw868g5yut09p6nytfmmnktexz2ya5uqg9vl9sss4euqm" + recipient, err := age.ParseX25519Recipient(publicKey) + if err != nil { + log.Fatalf("Failed to parse public key %q: %v", publicKey, err) + } + + in := strings.NewReader("Black lives matter.") + out := &bytes.Buffer{} + + err = age.EncryptReader(out, in, recipient) + if err != nil { + log.Fatalf("Failed to create encrypted file: %v", err) + } + + fmt.Printf("Encrypted file size: %d\n", out.Len()) + // Output: + // Encrypted file size: 219 +} + // DO NOT hardcode the private key. Store it in a secret storage solution, // on disk if the local machine is trusted, or have the user provide it. var privateKey string @@ -189,6 +209,65 @@ func TestEncryptDecryptScrypt(t *testing.T) { } } +func TestEncryptReaderX25519(t *testing.T) { + a, err := age.GenerateX25519Identity() + if err != nil { + t.Fatal(err) + } + b, err := age.GenerateX25519Identity() + if err != nil { + t.Fatal(err) + } + buf := &bytes.Buffer{} + err = age.EncryptReader(buf, strings.NewReader(helloWorld), a.Recipient(), b.Recipient()) + if err != nil { + t.Fatal(err) + } + + out, err := age.Decrypt(buf, b) + if err != nil { + t.Fatal(err) + } + outBytes, err := io.ReadAll(out) + if err != nil { + t.Fatal(err) + } + if string(outBytes) != helloWorld { + t.Errorf("wrong data: %q, excepted %q", outBytes, helloWorld) + } +} + +func TestEncryptReaderScrypt(t *testing.T) { + password := "twitch.tv/filosottile" + + r, err := age.NewScryptRecipient(password) + if err != nil { + t.Fatal(err) + } + r.SetWorkFactor(15) + buf := &bytes.Buffer{} + err = age.EncryptReader(buf, strings.NewReader(helloWorld), r) + if err != nil { + t.Fatal(err) + } + + i, err := age.NewScryptIdentity(password) + if err != nil { + t.Fatal(err) + } + out, err := age.Decrypt(buf, i) + if err != nil { + t.Fatal(err) + } + outBytes, err := io.ReadAll(out) + if err != nil { + t.Fatal(err) + } + if string(outBytes) != helloWorld { + t.Errorf("wrong data: %q, excepted %q", outBytes, helloWorld) + } +} + func TestParseIdentities(t *testing.T) { tests := []struct { name string diff --git a/cmd/age/age.go b/cmd/age/age.go index e5d17e2b..aaf41f16 100644 --- a/cmd/age/age.go +++ b/cmd/age/age.go @@ -402,16 +402,10 @@ func encrypt(recipients []age.Recipient, in io.Reader, out io.Writer, withArmor }() out = a } - w, err := age.Encrypt(out, recipients...) + err := age.EncryptReader(out, in, recipients...) if err != nil { errorf("%v", err) } - if _, err := io.Copy(w, in); err != nil { - errorf("%v", err) - } - if err := w.Close(); err != nil { - errorf("%v", err) - } } // crlfMangledIntro and utf16MangledIntro are the intro lines of the age format diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 7551274b..5ae5842b 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -215,16 +215,62 @@ const ( ) func (w *Writer) flushChunk(last bool) error { - if !last && len(w.unwritten) != ChunkSize { - panic("stream: internal error: flush called with partial chunk") + err := writeChunk(w.dst, w.a, &w.nonce, w.unwritten, last) + w.unwritten = w.buf[:0] + return err +} + +func writeChunk(dst io.Writer, aead cipher.AEAD, nonce *[chacha20poly1305.NonceSize]byte, plaintext []byte, last bool) error { + if !last && len(plaintext) != ChunkSize { + panic("stream: internal error: writeChunk called with partial chunk") } if last { - setLastChunkFlag(&w.nonce) + setLastChunkFlag(nonce) } - buf := w.a.Seal(w.buf[:0], w.nonce[:], w.unwritten, nil) - _, err := w.dst.Write(buf) - w.unwritten = w.buf[:0] - incNonce(&w.nonce) + buf := aead.Seal(plaintext[:0], nonce[:], plaintext, nil) + _, err := dst.Write(buf) + incNonce(nonce) return err } + +func Encrypt(dst io.Writer, src io.Reader, key []byte) error { + aead, err := chacha20poly1305.New(key) + if err != nil { + return err + } + + var nonce [chacha20poly1305.NonceSize]byte + var bufs [2][encChunkSize]byte + var hasUnwritten bool + + for unwritten, current := 0, 1; ; unwritten, current = current, unwritten { + n, err := io.ReadFull(src, bufs[current][:ChunkSize]) + if err == io.EOF { + if hasUnwritten { + return writeChunk(dst, aead, &nonce, bufs[unwritten][:ChunkSize], lastChunk) + } else { // empty payload + return writeChunk(dst, aead, &nonce, bufs[current][:0], lastChunk) + } + } else if err == io.ErrUnexpectedEOF { + if hasUnwritten { + err := writeChunk(dst, aead, &nonce, bufs[unwritten][:ChunkSize], notLastChunk) + if err != nil { + return err + } + } + return writeChunk(dst, aead, &nonce, bufs[current][:n], lastChunk) + } else if err != nil { + return err + } + + if hasUnwritten { + err := writeChunk(dst, aead, &nonce, bufs[unwritten][:ChunkSize], notLastChunk) + if err != nil { + return err + } + } else { + hasUnwritten = true + } + } +} diff --git a/internal/stream/stream_test.go b/internal/stream/stream_test.go index 8cac9674..c92f76b1 100644 --- a/internal/stream/stream_test.go +++ b/internal/stream/stream_test.go @@ -8,6 +8,7 @@ import ( "bytes" "crypto/rand" "fmt" + "io" "testing" "filippo.io/age/internal/stream" @@ -91,3 +92,42 @@ func testRoundTrip(t *testing.T, stepSize, length int) { n += nn } } + +func TestEncrypt(t *testing.T) { + for _, mul := range []int{0, 1, 2, 3} { + for _, add := range []int{0, 1, 2, 3, stream.ChunkSize - 1} { + length := mul*stream.ChunkSize + add + + t.Run(fmt.Sprintf("length=%d", length), func(t *testing.T) { + src := make([]byte, length) + if _, err := rand.Read(src); err != nil { + t.Fatal(err) + } + buf := &bytes.Buffer{} + key := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + err := stream.Encrypt(buf, bytes.NewReader(src), key) + if err != nil { + t.Fatal(err) + } + + r, err := stream.NewReader(key, buf) + if err != nil { + t.Fatal(err) + } + + dec, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(src, dec) { + t.Errorf("Wrong decrypted data") + } + }) + } + } +}