Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions schemes/bgv/bgv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,40 @@ func testEncoder(tc *TestContext, t *testing.T) {
require.True(t, slices.Equal(coeffs, have))
})
}

for _, lvl := range testLevel {
lvl := lvl
t.Run(name("Encoder/Poly", tc, lvl), func(t *testing.T) {
t.Parallel()

poly := tc.Sampler.ReadNew()
pt := NewPlaintext(tc.Params, lvl)
err := tc.Ecd.Encode(poly, pt)
require.NoError(t, err)
have := tc.Params.RingT().NewPoly()
err = tc.Ecd.Decode(pt, have)
require.NoError(t, err)
require.True(t, poly.Equal(&have))

// decoding into []uint64 should also work
coeffs := make([]uint64, tc.Params.RingT().N())
err = tc.Ecd.Decode(pt, coeffs)
require.NoError(t, err)
require.True(t, slices.Equal(poly.Coeffs[0], coeffs))
})
}

t.Run(name("Encoder/UnsupportedType", tc, 0), func(t *testing.T) {

pt := NewPlaintext(tc.Params, 0)
err := tc.Ecd.Encode("encoding some string", pt)
require.Error(t, err, "expected error when encoding unsupported type")

pt.IsBatched = false
err = tc.Ecd.Encode("encoding some string", pt)
require.Error(t, err, "expected error when encoding unsupported type")
})

}

func testEvaluatorBvg(tc *TestContext, t *testing.T) {
Expand Down
25 changes: 21 additions & 4 deletions schemes/bgv/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,22 @@ func (ecd Encoder) GetRLWEParameters() *rlwe.Parameters {
return &ecd.parameters.Parameters
}

// Encode encodes an [IntegerSlice] of size at most n on a pre-allocated plaintext,
// where n is the largest value satisfying PlaintextModulus = 1 mod 2n if pt.IsBatched=true,
// or the value of N set in the parameters otherwise.
// Encode encodes values on a pre-allocated plaintext. The `values` must be of type [IntegerSlice] or be a [ring.Poly] from params.RingT.
// If `values` is of type [ring.Poly], then pt.IsBatched is set to false and the polynomial is encoded directly (i.e., scaled).
// If `values` is of type [IntegerSlice], then the encoding depends respects the pt.IsBatched flag:
// - If pt.IsBatched=false, then values are interpreted as the coefficients of a polynomial and encored as above.
// - If pt.IsBatched=true, then values are encoded in a SIMD fashion on n slots, where n is the largest value satisfying PlaintextModulus = 1 mod 2n.
func (ecd Encoder) Encode(values interface{}, pt *rlwe.Plaintext) (err error) {

poly, isPoly := values.(ring.Poly)
if isPoly {
if poly.N() != ecd.parameters.RingT().N() {
return fmt.Errorf("cannot Encode: poly.N()=%d != RingT.N()=%d", poly.N(), ecd.parameters.RingT().N())
}
pt.IsBatched = false
values = poly.Coeffs[0]
}

if pt.IsBatched {
return ecd.EmbedScale(values, true, pt.MetaData, pt.Value)
} else {
Expand Down Expand Up @@ -182,6 +193,8 @@ func (ecd Encoder) Encode(values interface{}, pt *rlwe.Plaintext) (err error) {
}

valLen = len(values)
default:
return fmt.Errorf("cannot Encode: values.(type) must be either IntegerSlice or ring.Poly but is %T", values)
}

for i := valLen; i < N; i++ {
Expand Down Expand Up @@ -466,7 +479,9 @@ func (ecd Encoder) RingQ2T(level int, scaleDown bool, pQ, pT ring.Poly) {
}
}

// Decode decodes a plaintext on an IntegerSlice mod PlaintextModulus of size at most N, where N is the smallest value satisfying PlaintextModulus = 1 mod 2N.
// Decode decodes a [Plaintext] into values of type [IntegerSlice] or [ring.Poly].
// If pt.IsBatched=true, then values must be a [IntegerSlice] and the plaintext is decoded in a SIMD fashion from n slots, where n is the largest value satisfying PlaintextModulus = 1 mod 2n.
// If pt.IsBatched=false, then values can be either a [ring.Poly] from Parameters.RingT or an [IntegerSlice] and the plaintext is decoded as the coefficients of a polynomial
func (ecd Encoder) Decode(pt *rlwe.Plaintext, values interface{}) (err error) {

var buffT *ring.Poly
Expand Down Expand Up @@ -516,6 +531,8 @@ func (ecd Encoder) Decode(pt *rlwe.Plaintext, values interface{}) (err error) {
values[i] = value
}
}
case ring.Poly:
copy(values.Coeffs[0], buffT.Coeffs[0])

default:
return fmt.Errorf("cannot Decode: values must be either []uint64 or []int64 but is %T", values)
Expand Down
4 changes: 2 additions & 2 deletions schemes/bgv/test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ func VerifyTestVectors(params Parameters, encoder *Encoder, decryptor *rlwe.Decr
t.Error("invalid unsupported test object type")
}

fmt.Println("have", values[:10])
fmt.Println("want", want[:10])
// fmt.Println("have", values[:10])
// fmt.Println("want", want[:10])
Comment on lines +90 to +91
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// fmt.Println("have", values[:10])
// fmt.Println("want", want[:10])

require.True(t, slices.Equal(values, want))
}

Expand Down
12 changes: 12 additions & 0 deletions schemes/ckks/ckks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,18 @@ func testEncoder(tc *TestContext, t *testing.T) {

require.GreaterOrEqual(t, math.Log2(1/meanprec), minPrec)
})

t.Run(name("Encoder/UnsupportedType", tc), func(t *testing.T) {

pt := NewPlaintext(tc.Params, 0)
err := tc.Ecd.Encode("encoding some string", pt)
require.Error(t, err, "expected error when encoding unsupported type")

pt.IsBatched = false
err = tc.Ecd.Encode("encoding some string", pt)
require.Error(t, err, "expected error when encoding unsupported type")
})

}

func testEvaluatorAdd(tc *TestContext, t *testing.T) {
Expand Down
Loading