Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 51 additions & 17 deletions age.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ type Stanza struct {
const fileKeySize = 16
const streamNonceSize = 16

var errNoRecipients = errors.New("no recipients specified")

// Encrypt encrypts a file to one or more recipients.
//
// Writes to the returned WriteCloser are encrypted and written to dst as an age
Expand All @@ -122,49 +124,81 @@ const streamNonceSize = 16
// be encrypted and flushed to dst.
func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) {
if len(recipients) == 0 {
return nil, errors.New("no recipients specified")
return nil, errNoRecipients
}

fileKey := make([]byte, fileKeySize)
if _, err := rand.Read(fileKey); err != nil {
return nil, err
}

if err := writeHeader(dst, recipients, fileKey); err != nil {
return nil, fmt.Errorf("failed to write header: %v", err)
}

nonce := make([]byte, streamNonceSize)
if _, err := rand.Read(nonce); err != nil {
return nil, err
}
if _, err := dst.Write(nonce); err != nil {
return nil, fmt.Errorf("failed to write nonce: %v", err)
}

return stream.NewWriter(streamKey(fileKey, nonce), dst)
}

// EncryptReader encrypts a reader to one or more recipients.
//
// It encrypts src into dst as an age file. Every recipient will be able to decrypt the file.
func EncryptReader(dst io.Writer, src io.Reader, recipients ...Recipient) error {
if len(recipients) == 0 {
return errNoRecipients
}

fileKey := make([]byte, fileKeySize)
if _, err := rand.Read(fileKey); err != nil {
return err
}

if err := writeHeader(dst, recipients, fileKey); err != nil {
return fmt.Errorf("failed to write header: %v", err)
}

nonce := make([]byte, streamNonceSize)
if _, err := rand.Read(nonce); err != nil {
return err
}
if _, err := dst.Write(nonce); err != nil {
return fmt.Errorf("failed to write nonce: %v", err)
}

return stream.Encrypt(dst, src, streamKey(fileKey, nonce))
}

func writeHeader(dst io.Writer, recipients []Recipient, fileKey []byte) error {
hdr := &format.Header{}
var labels []string
for i, r := range recipients {
stanzas, l, err := wrapWithLabels(r, fileKey)
if err != nil {
return nil, fmt.Errorf("failed to wrap key for recipient #%d: %v", i, err)
return fmt.Errorf("failed to wrap key for recipient #%d: %v", i, err)
}
sort.Strings(l)
if i == 0 {
labels = l
} else if !slicesEqual(labels, l) {
return nil, fmt.Errorf("incompatible recipients")
return fmt.Errorf("incompatible recipients")
}
for _, s := range stanzas {
hdr.Recipients = append(hdr.Recipients, (*format.Stanza)(s))
}
}
if mac, err := headerMAC(fileKey, hdr); err != nil {
return nil, fmt.Errorf("failed to compute header MAC: %v", err)
return fmt.Errorf("failed to compute header MAC: %v", err)
} else {
hdr.MAC = mac
}
if err := hdr.Marshal(dst); err != nil {
return nil, fmt.Errorf("failed to write header: %v", err)
}

nonce := make([]byte, streamNonceSize)
if _, err := rand.Read(nonce); err != nil {
return nil, err
}
if _, err := dst.Write(nonce); err != nil {
return nil, fmt.Errorf("failed to write nonce: %v", err)
}

return stream.NewWriter(streamKey(fileKey, nonce), dst)
return hdr.Marshal(dst)
}

func wrapWithLabels(r Recipient, fileKey []byte) (s []*Stanza, labels []string, err error) {
Expand Down
79 changes: 79 additions & 0 deletions age_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,26 @@ func ExampleEncrypt() {
// Encrypted file size: 219
}

func ExampleEncryptReader() {
publicKey := "age1cy0su9fwf3gf9mw868g5yut09p6nytfmmnktexz2ya5uqg9vl9sss4euqm"
recipient, err := age.ParseX25519Recipient(publicKey)
if err != nil {
log.Fatalf("Failed to parse public key %q: %v", publicKey, err)
}

in := strings.NewReader("Black lives matter.")
out := &bytes.Buffer{}

err = age.EncryptReader(out, in, recipient)
if err != nil {
log.Fatalf("Failed to create encrypted file: %v", err)
}

fmt.Printf("Encrypted file size: %d\n", out.Len())
// Output:
// Encrypted file size: 219
}

// DO NOT hardcode the private key. Store it in a secret storage solution,
// on disk if the local machine is trusted, or have the user provide it.
var privateKey string
Expand Down Expand Up @@ -189,6 +209,65 @@ func TestEncryptDecryptScrypt(t *testing.T) {
}
}

func TestEncryptReaderX25519(t *testing.T) {
a, err := age.GenerateX25519Identity()
if err != nil {
t.Fatal(err)
}
b, err := age.GenerateX25519Identity()
if err != nil {
t.Fatal(err)
}
buf := &bytes.Buffer{}
err = age.EncryptReader(buf, strings.NewReader(helloWorld), a.Recipient(), b.Recipient())
if err != nil {
t.Fatal(err)
}

out, err := age.Decrypt(buf, b)
if err != nil {
t.Fatal(err)
}
outBytes, err := io.ReadAll(out)
if err != nil {
t.Fatal(err)
}
if string(outBytes) != helloWorld {
t.Errorf("wrong data: %q, excepted %q", outBytes, helloWorld)
}
}

func TestEncryptReaderScrypt(t *testing.T) {
password := "twitch.tv/filosottile"

r, err := age.NewScryptRecipient(password)
if err != nil {
t.Fatal(err)
}
r.SetWorkFactor(15)
buf := &bytes.Buffer{}
err = age.EncryptReader(buf, strings.NewReader(helloWorld), r)
if err != nil {
t.Fatal(err)
}

i, err := age.NewScryptIdentity(password)
if err != nil {
t.Fatal(err)
}
out, err := age.Decrypt(buf, i)
if err != nil {
t.Fatal(err)
}
outBytes, err := io.ReadAll(out)
if err != nil {
t.Fatal(err)
}
if string(outBytes) != helloWorld {
t.Errorf("wrong data: %q, excepted %q", outBytes, helloWorld)
}
}

func TestParseIdentities(t *testing.T) {
tests := []struct {
name string
Expand Down
8 changes: 1 addition & 7 deletions cmd/age/age.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,16 +402,10 @@ func encrypt(recipients []age.Recipient, in io.Reader, out io.Writer, withArmor
}()
out = a
}
w, err := age.Encrypt(out, recipients...)
err := age.EncryptReader(out, in, recipients...)
if err != nil {
errorf("%v", err)
}
if _, err := io.Copy(w, in); err != nil {
errorf("%v", err)
}
if err := w.Close(); err != nil {
errorf("%v", err)
}
}

// crlfMangledIntro and utf16MangledIntro are the intro lines of the age format
Expand Down
60 changes: 53 additions & 7 deletions internal/stream/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,16 +215,62 @@ const (
)

func (w *Writer) flushChunk(last bool) error {
if !last && len(w.unwritten) != ChunkSize {
panic("stream: internal error: flush called with partial chunk")
err := writeChunk(w.dst, w.a, &w.nonce, w.unwritten, last)
w.unwritten = w.buf[:0]
return err
}

func writeChunk(dst io.Writer, aead cipher.AEAD, nonce *[chacha20poly1305.NonceSize]byte, plaintext []byte, last bool) error {
if !last && len(plaintext) != ChunkSize {
panic("stream: internal error: writeChunk called with partial chunk")
}

if last {
setLastChunkFlag(&w.nonce)
setLastChunkFlag(nonce)
}
buf := w.a.Seal(w.buf[:0], w.nonce[:], w.unwritten, nil)
_, err := w.dst.Write(buf)
w.unwritten = w.buf[:0]
incNonce(&w.nonce)
buf := aead.Seal(plaintext[:0], nonce[:], plaintext, nil)
_, err := dst.Write(buf)
incNonce(nonce)
return err
}

func Encrypt(dst io.Writer, src io.Reader, key []byte) error {
aead, err := chacha20poly1305.New(key)
if err != nil {
return err
}

var nonce [chacha20poly1305.NonceSize]byte
var bufs [2][encChunkSize]byte
var hasUnwritten bool

for unwritten, current := 0, 1; ; unwritten, current = current, unwritten {
n, err := io.ReadFull(src, bufs[current][:ChunkSize])
if err == io.EOF {
if hasUnwritten {
return writeChunk(dst, aead, &nonce, bufs[unwritten][:ChunkSize], lastChunk)
} else { // empty payload
return writeChunk(dst, aead, &nonce, bufs[current][:0], lastChunk)
}
} else if err == io.ErrUnexpectedEOF {
if hasUnwritten {
err := writeChunk(dst, aead, &nonce, bufs[unwritten][:ChunkSize], notLastChunk)
if err != nil {
return err
}
}
return writeChunk(dst, aead, &nonce, bufs[current][:n], lastChunk)
} else if err != nil {
return err
}

if hasUnwritten {
err := writeChunk(dst, aead, &nonce, bufs[unwritten][:ChunkSize], notLastChunk)
if err != nil {
return err
}
} else {
hasUnwritten = true
}
}
}
40 changes: 40 additions & 0 deletions internal/stream/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"crypto/rand"
"fmt"
"io"
"testing"

"filippo.io/age/internal/stream"
Expand Down Expand Up @@ -91,3 +92,42 @@ func testRoundTrip(t *testing.T, stepSize, length int) {
n += nn
}
}

func TestEncrypt(t *testing.T) {
for _, mul := range []int{0, 1, 2, 3} {
for _, add := range []int{0, 1, 2, 3, stream.ChunkSize - 1} {
length := mul*stream.ChunkSize + add

t.Run(fmt.Sprintf("length=%d", length), func(t *testing.T) {
src := make([]byte, length)
if _, err := rand.Read(src); err != nil {
t.Fatal(err)
}
buf := &bytes.Buffer{}
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}

err := stream.Encrypt(buf, bytes.NewReader(src), key)
if err != nil {
t.Fatal(err)
}

r, err := stream.NewReader(key, buf)
if err != nil {
t.Fatal(err)
}

dec, err := io.ReadAll(r)
if err != nil {
t.Fatal(err)
}

if !bytes.Equal(src, dec) {
t.Errorf("Wrong decrypted data")
}
})
}
}
}
Loading