From df182800bb6e58c1cbee39fa86b375d4b759bc10 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 30 Aug 2021 12:27:18 -0700 Subject: [PATCH] Performance Optimizations (#5) Enabling Lazy NTT, giving a speedup of about 2x for encryption/decryption/rerandomization. * Updated version to 0.2.0 * added poly operation benchmarks * added butterfly mod and tests * cleaned up interfaces. * basic lazy butterfly tests work * lazy butterfly and tests * added trait and tests for lazy transforms * swap * switch to use split_at * added lazy (i)ntt; tests pass. * conditional compilation * benchmarking * optimized performance by pure u64 impl * added benchmark for integer ops and gaussian sampling * sample ternary * optimized uniform sampling * clean up * enabled lazy ntt by default * fixing reviewer comments * fix underflow * further addressing comments * updating version to 0.2.1 * remove unused dependencies Co-authored-by: Hao Chen --- .gitignore | 1 + CHANGELOG.md | 8 + Cargo.toml | 20 +- README.md | 2 +- benches/butterfly.rs | 73 +++++++ benches/example.rs | 270 -------------------------- benches/integerops.rs | 27 +++ benches/polyops.rs | 74 ++++++++ benches/scheme.rs | 104 ++++++++++ examples/basic.rs | 1 - examples/rerandomize.rs | 2 +- src/integer_arith/butterfly.rs | 241 +++++++++++++++++++++++ src/integer_arith/mod.rs | 17 +- src/integer_arith/scalar.rs | 139 +++++++++++--- src/integer_arith/util.rs | 38 ++++ src/lib.rs | 90 +++++---- src/polyarith/lazy_ntt.rs | 46 +++++ src/polyarith/mod.rs | 1 + src/randutils.rs | 97 ++++++++++ src/rqpoly.rs | 338 +++++++++++++++------------------ src/serialize.rs | 12 +- src/traits.rs | 17 ++ 22 files changed, 1086 insertions(+), 532 deletions(-) create mode 100644 CHANGELOG.md create mode 100644 benches/butterfly.rs delete mode 100644 benches/example.rs create mode 100644 benches/integerops.rs create mode 100644 benches/polyops.rs create mode 100644 benches/scheme.rs create mode 100644 src/integer_arith/butterfly.rs create mode 100644 
src/integer_arith/util.rs create mode 100644 src/polyarith/lazy_ntt.rs create mode 100644 src/polyarith/mod.rs create mode 100644 src/randutils.rs diff --git a/.gitignore b/.gitignore index 96ef6c0..5edd8e0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target Cargo.lock +src/*.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..eaacede --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,8 @@ +## 0.2.1 (August 30, 2021) + +* Performance optimizations: Faster encryption/decryption based on lazy NTT + +## 0.2.0 (June 1, 2021) + +* Added serialization support +* Added ability to customize plaintext modulus diff --git a/Cargo.toml b/Cargo.toml index c47a822..b8ccf53 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "Cupcake" -version = "0.1.1" +version = "0.2.1" authors = ["Hao Chen "] license = "MIT" edition = "2018" @@ -20,6 +20,9 @@ bencher = "0.1.5" name = "cupcake" path = "src/lib.rs" +[features] +bench = [] + [[example]] name = "basic" @@ -27,5 +30,18 @@ name = "basic" name = "serialization" [[bench]] -name = "example" +name = "scheme" +harness = false + +[[bench]] +name = "polyops" +harness = false +required-features = ["bench"] + +[[bench]] +name = "butterfly" +harness = false + +[[bench]] +name = "integerops" harness = false diff --git a/README.md b/README.md index 7e48b3f..5c0e435 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Cupcake requires or works with ## Installation Add the following line to the dependencies of your Cargo.toml: ``` -Cupcake = "0.1.1" +Cupcake = "0.2.1" ``` ## Building from source diff --git a/benches/butterfly.rs b/benches/butterfly.rs new file mode 100644 index 0000000..a94d41b --- /dev/null +++ b/benches/butterfly.rs @@ -0,0 +1,73 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
+#[macro_use] +extern crate bencher; +use bencher::Bencher; +use cupcake::integer_arith::butterfly::{ + butterfly, inverse_butterfly, lazy_butterfly, lazy_butterfly_u64, +}; +use cupcake::integer_arith::scalar::Scalar; +use cupcake::integer_arith::ArithUtils; + +#[allow(non_snake_case)] +fn bench_butterfly(bench: &mut Bencher) { + let q = Scalar::new_modulus(18014398492704769u64); + let x = rand::random::(); + let y = rand::random::(); + let w = rand::random::(); + + let mut X = Scalar::from(x); + let mut Y = Scalar::from(y); + let W = Scalar::from(w); + + bench.iter(|| { + let _ = butterfly(&mut X, &mut Y, &W, &q); + }) +} + +#[allow(non_snake_case)] +fn bench_inverse_butterfly(bench: &mut Bencher) { + let q = Scalar::new_modulus(18014398492704769u64); + let x = rand::random::(); + let y = rand::random::(); + let w = rand::random::(); + + let mut X = Scalar::from(x); + let mut Y = Scalar::from(y); + let W = Scalar::from(w); + + bench.iter(|| { + let _ = inverse_butterfly(&mut X, &mut Y, &W, &q); + }) +} + +#[allow(non_snake_case)] +fn bench_lazy_butterfly(bench: &mut Bencher) { + let q = Scalar::new_modulus(18014398492704769u64); + let x = rand::random::(); + let y = rand::random::(); + let w = rand::random::(); + + let mut X = Scalar::from(x); + let mut Y = Scalar::from(y); + let W = Scalar::from(w); + + let Wprime: u64 = cupcake::integer_arith::util::compute_harvey_ratio(W.rep(), q.rep()); + + let twoq: u64 = q.rep() << 1; + + bench.iter(|| { + let _ = lazy_butterfly_u64(x, y, W.rep(), Wprime, q.rep(), twoq); + }) +} + +benchmark_group!( + butterfly_group, + bench_butterfly, + bench_inverse_butterfly, + bench_lazy_butterfly +); + +benchmark_main!(butterfly_group); diff --git a/benches/example.rs b/benches/example.rs deleted file mode 100644 index 7d55275..0000000 --- a/benches/example.rs +++ /dev/null @@ -1,270 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. 
-// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree. -#[macro_use] -extern crate bencher; - -use bencher::Bencher; -use cupcake::traits::*; -// use cupcake::rqpoly::{RqPoly, RqPolyContext}; -// use rand::rngs::StdRng; -// use rand::FromEntropy; - -// fn mod_mul(bench: &mut Bencher) { -// let fv = FV::::default_2048(); -// let a = BigInt::sample_below(&fv.q); -// let b = BigInt::sample_below(&fv.q); -// bench.iter(|| { -// let _ = BigInt::mod_mul(&a, &b, &fv.q); -// }) -// } - -// fn scalar_ntt(bench: &mut Bencher) { -// let q = Scalar::new_modulus(18014398492704769u64); -// let context = Arc::new(RqPolyContext::new(2048, &q)); -// let mut testpoly = cupcake::randutils::sample_uniform_poly(context.clone()); - -// bench.iter(|| { -// testpoly.is_ntt_form = false; -// let _ = testpoly.forward_transform(); -// }) -// } - -// fn scalar_intt(bench: &mut Bencher) { -// let q = Scalar::new_modulus(18014398492704769u64); -// let context = Arc::new(RqPolyContext::new(2048, &q)); - -// let mut testpoly = cupcake::randutils::sample_uniform_poly(context.clone()); - -// bench.iter(|| { -// testpoly.is_ntt_form = true; -// let _ = testpoly.inverse_transform(); -// }) -// } - -// fn mod_mul_fast(bench: &mut Bencher){ -// let q = Scalar::new_modulus(18014398492704769); - -// let ratio = (17592185012223u64, 1024u64); -// let mut a = Scalar::sample_blw(&q); -// let mut b = Scalar::sample_blw(&q); -// a = Scalar::modulus(&a, &q); -// b = Scalar::modulus(&b, &q); - -// bench.iter(|| { -// let _ = Scalar::barret_multiply(&a, &b, ratio, q.rep); -// }) -// } - -// fn mod_mul_fast_wrap(bench: &mut Bencher) { -// let q = Scalar::from_u64_raw(18014398492704769); - -// let ratio = (17592185012223u64, 1024u64); -// let mut a = Scalar::sample_blw(&q); -// let mut b = Scalar::sample_blw(&q); -// a = Scalar::modulus(&a, &q); -// b = Scalar::modulus(&b, &q); - -// bench.iter(|| { -// let _ = Scalar::mul_mod(&a, &b, &q); 
-// }) -// } - -// fn ntt_multiply(bench: &mut Bencher) { -// let fv = FV::::default_2048(); -// // let context = RqPolyContext::new(fv.n, &fv.q); -// // fv.context = Arc::new(context); -// let a = fv.sample_uniform_poly(); -// let b = fv.sample_uniform_poly(); -// bench.iter(|| { -// let _ = a.multiply_fast(&b); -// }) -// } - -// fn bigint_ntt(bench: &mut Bencher) { -// let fv = FV::::default_2048(); -// // let context = RqPolyContext::new(fv.n, &fv.q); -// // fv.context = Arc::new(context); -// let mut a = fv.sample_uniform_poly(); -// bench.iter(|| { -// a.is_ntt_form = false; -// let _ = a.forward_transform(); -// }) -// } - -// fn bigint_intt(bench: &mut Bencher) { -// let fv = FV::::default_2048(); -// let mut a = fv.sample_uniform_poly(); -// bench.iter(|| { -// a.is_ntt_form = true; -// let _ = a.inverse_transform(); -// }) -// } - -// fn sample_uniform(bench: &mut Bencher) { -// let fv = FV::::default_2048(); - -// bench.iter(|| { -// let _ = randutils::sample_uniform_poly(fv.context.clone()); -// }) -// } - -// fn sample_gaussian(bench: &mut Bencher) { -// let fv = FV::::default_2048(); - -// bench.iter(|| { -// let _ = fv.sample_gaussian_poly(fv.stdev); -// }) -// } - -// fn sample_uniform_scalar(bench: &mut Bencher) { -// let fv = FV::::default_2048(); - -// bench.iter(|| { -// let _ = Scalar::sample_blw(&fv.q); -// }) -// } - -// fn sample_binary(bench: &mut Bencher) { -// let fv = FV::::default_2048(); - -// bench.iter(|| { -// let _ = fv.sample_binary_poly(); -// }) -// } - -// fn sample_binary_prng(bench: &mut Bencher) { -// let fv = FV::::default_2048(); - -// bench.iter(|| { -// let _ = fv.sample_binary_poly_prng(); -// }) -// } - -// fn sample_uniform_scalar_from_rng(bench: &mut Bencher) { -// let fv = FV::::default_2048(); -// let mut rng = StdRng::from_entropy(); -// bench.iter(|| { -// let _ = Scalar::sample_below_from_rng(&fv.q, &mut rng); -// }) -// } - -// fn sample_from_rng(bench: &mut Bencher) { -// let fv = FV::::default_2048(); -// let 
mut rng = StdRng::from_entropy(); -// bench.iter(|| { -// let _ = Scalar::_sample_form_rng(fv.q.bit_count, &mut rng); -// }) -// } - -fn encrypt_sk(bench: &mut Bencher) { - let fv = cupcake::default(); - - let sk = fv.generate_key(); - - let mut v = vec![0; fv.n]; - for i in 0..fv.n { - v[i] = i as u8; - } - bench.iter(|| { - let _ = fv.encrypt_sk(&v, &sk); - }) -} - -fn decryption(bench: &mut Bencher) { - let fv = cupcake::default(); - - let sk = fv.generate_key(); - let mut v = vec![0; fv.n]; - for i in 0..fv.n { - v[i] = i as u8; - } - let ct = fv.encrypt_sk(&v, &sk); - bench.iter(|| { - let _ = fv.decrypt(&ct, &sk); - }) -} - -fn encrypt_pk(bench: &mut Bencher) { - let fv = cupcake::default(); - - let (pk, _sk) = fv.generate_keypair(); - let mut v = vec![0; fv.n]; - for i in 0..fv.n { - v[i] = i as u8; - } - bench.iter(|| { - let _ = fv.encrypt(&v, &pk); - }) -} - -fn encrypt_zero_pk(bench: &mut Bencher) { - let fv = cupcake::default(); - - let (pk, _sk) = fv.generate_keypair(); - let mut v = vec![0; fv.n]; - for i in 0..fv.n { - v[i] = i as u8; - } - bench.iter(|| { - let _ = fv.encrypt_zero(&pk); - }) -} - -fn homomorphic_addition(bench: &mut Bencher) { - let fv = cupcake::default(); - - let sk = fv.generate_key(); - - let mut v = vec![0; fv.n]; - for i in 0..fv.n { - v[i] = i as u8; - } - let mut ct1 = fv.encrypt_sk(&v, &sk); - let ct2 = fv.encrypt_sk(&v, &sk); - bench.iter(|| { - fv.add_inplace(&mut ct1, &ct2); - }) -} - -fn rerandomize(bench: &mut Bencher) { - let fv = cupcake::default(); - - let (pk, _) = fv.generate_keypair(); - let mut v = vec![0; fv.n]; - for i in 0..fv.n { - v[i] = i as u8; - } - let mut ct = fv.encrypt(&v, &pk); - - bench.iter(|| { - fv.rerandomize(&mut ct, &pk); - }) -} - -// benchmark_group!( -// scalarop, -// sample_uniform_scalar, -// sample_uniform_scalar_from_rng -// ); -// benchmark_group!( -// polyop, -// sample_uniform, -// sample_gaussian, -// sample_binary, -// scalar_ntt, -// scalar_intt, -// sample_binary_prng -// ); 
-benchmark_group!( - fvop, - encrypt_sk, - encrypt_pk, - encrypt_zero_pk, - decryption, - homomorphic_addition, - rerandomize, -); - -benchmark_main!(fvop); diff --git a/benches/integerops.rs b/benches/integerops.rs new file mode 100644 index 0000000..4903f8e --- /dev/null +++ b/benches/integerops.rs @@ -0,0 +1,27 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. +#[macro_use] +extern crate bencher; +use bencher::Bencher; +use cupcake::integer_arith::scalar::Scalar; +use cupcake::integer_arith::ArithUtils; + +#[allow(non_snake_case)] +fn bench_mulmod(bench: &mut Bencher) { + let q = Scalar::new_modulus(18014398492704769u64); + let x = rand::random::(); + let y = rand::random::(); + + let X = Scalar::from(x); + let Y = Scalar::from(y); + + bench.iter(|| { + let _ = Scalar::mul_mod(&X, &Y, &q); + }) +} + +benchmark_group!(integerops_group, bench_mulmod,); + +benchmark_main!(integerops_group); diff --git a/benches/polyops.rs b/benches/polyops.rs new file mode 100644 index 0000000..7698803 --- /dev/null +++ b/benches/polyops.rs @@ -0,0 +1,74 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
+#[macro_use] +extern crate bencher; +use bencher::Bencher; +use cupcake::rqpoly::RqPolyContext; +use cupcake::traits::*; +pub use std::sync::Arc; +use cupcake::integer_arith::scalar::Scalar; +use cupcake::integer_arith::ArithUtils; +use cupcake::randutils; + +fn scalar_ntt(bench: &mut Bencher) { + let q = Scalar::new_modulus(18014398492704769u64); + let context = Arc::new(RqPolyContext::new(2048, &q)); + let mut testpoly = randutils::sample_uniform_poly(context); + + bench.iter(|| { + testpoly.is_ntt_form = false; + let _ = testpoly.forward_transform(); + }) +} + +fn scalar_intt(bench: &mut Bencher) { + let q = Scalar::new_modulus(18014398492704769u64); + let context = Arc::new(RqPolyContext::new(2048, &q)); + + let mut testpoly = cupcake::randutils::sample_uniform_poly(context.clone()); + + bench.iter(|| { + testpoly.is_ntt_form = true; + let _ = testpoly.inverse_transform(); + }) +} + +fn sample_uniform(bench: &mut Bencher) { + let q = Scalar::new_modulus(18014398492704769u64); + let context = Arc::new(RqPolyContext::new(2048, &q)); + + bench.iter(|| { + let _ = randutils::sample_uniform_poly(context.clone()); + }) +} + +fn sample_gaussian(bench: &mut Bencher) { + let q = Scalar::new_modulus(18014398492704769u64); + let context = Arc::new(RqPolyContext::new(2048, &q)); + + bench.iter(|| { + let _ = randutils::sample_gaussian_poly(context.clone(), 3.14); + }) +} + +fn sample_ternary(bench: &mut Bencher) { + let q = Scalar::new_modulus(18014398492704769u64); + let context = Arc::new(RqPolyContext::new(2048, &q)); + + bench.iter(|| { + let _ = randutils::sample_ternary_poly_prng(context.clone()); + }) +} + +benchmark_group!( + polyops, + sample_gaussian, + sample_ternary, + sample_uniform, + scalar_ntt, + scalar_intt, +); + +benchmark_main!(polyops); diff --git a/benches/scheme.rs b/benches/scheme.rs new file mode 100644 index 0000000..883cc63 --- /dev/null +++ b/benches/scheme.rs @@ -0,0 +1,104 @@ +// Copyright (c) Facebook, Inc. and its affiliates. 
+// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. +#[macro_use] +extern crate bencher; +use bencher::Bencher; +use cupcake::traits::*; +pub use std::sync::Arc; + +fn encrypt_sk(bench: &mut Bencher) { + let fv = cupcake::default(); + + let sk = fv.generate_key(); + + let v = (0..fv.n).map(|i| i as u8).collect::>(); + + bench.iter(|| { + let _ = fv.encrypt_sk(&v, &sk); + }) +} + +fn decryption(bench: &mut Bencher) { + let fv = cupcake::default(); + + let sk = fv.generate_key(); + let mut v = vec![0; fv.n]; + for i in 0..fv.n { + v[i] = i as u8; + } + let ct = fv.encrypt_sk(&v, &sk); + bench.iter(|| { + let _: Vec = fv.decrypt(&ct, &sk); + }) +} + +fn encrypt_pk(bench: &mut Bencher) { + let fv = cupcake::default(); + + let (pk, _sk) = fv.generate_keypair(); + let mut v = vec![0; fv.n]; + for i in 0..fv.n { + v[i] = i as u8; + } + bench.iter(|| { + let _ = fv.encrypt(&v, &pk); + }) +} + +fn encrypt_zero_pk(bench: &mut Bencher) { + let fv = cupcake::default(); + + let (pk, _sk) = fv.generate_keypair(); + let mut v = vec![0; fv.n]; + for i in 0..fv.n { + v[i] = i as u8; + } + bench.iter(|| { + let _ = fv.encrypt_zero(&pk); + }) +} + +fn homomorphic_addition(bench: &mut Bencher) { + let fv = cupcake::default(); + + let sk = fv.generate_key(); + + let mut v = vec![0; fv.n]; + for i in 0..fv.n { + v[i] = i as u8; + } + let mut ct1 = fv.encrypt_sk(&v, &sk); + let ct2 = fv.encrypt_sk(&v, &sk); + bench.iter(|| { + fv.add_inplace(&mut ct1, &ct2); + }) +} + +fn rerandomize(bench: &mut Bencher) { + let fv = cupcake::default(); + + let (pk, _) = fv.generate_keypair(); + let mut v = vec![0; fv.n]; + for i in 0..fv.n { + v[i] = i as u8; + } + let mut ct = fv.encrypt(&v, &pk); + + bench.iter(|| { + fv.rerandomize(&mut ct, &pk); + }) +} + +benchmark_group!( + scheme, + encrypt_sk, + encrypt_pk, + encrypt_zero_pk, + decryption, + homomorphic_addition, + rerandomize, +); + +benchmark_main!(scheme); diff 
--git a/examples/basic.rs b/examples/basic.rs index 199506b..e8c8a06 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -24,7 +24,6 @@ fn main() { print!("Encrypting a constant vector w of 2s..."); let w = vec![2; fv.n]; - let ctw = fv.encrypt(&w, &pk); pt_actual = fv.decrypt(&ctw, &sk); diff --git a/examples/rerandomize.rs b/examples/rerandomize.rs index 3d3bfaa..530c654 100644 --- a/examples/rerandomize.rs +++ b/examples/rerandomize.rs @@ -2,7 +2,7 @@ // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use cupcake::traits::{AdditiveHomomorphicScheme, PKEncryption, SKEncryption, KeyGeneration}; +use cupcake::traits::{AdditiveHomomorphicScheme, KeyGeneration, PKEncryption, SKEncryption}; fn smartprint(v: &Vec) { println!("[{:?}, {:?}, ..., {:?}]", v[0], v[1], v[v.len() - 1]); diff --git a/src/integer_arith/butterfly.rs b/src/integer_arith/butterfly.rs new file mode 100644 index 0000000..a4edb10 --- /dev/null +++ b/src/integer_arith/butterfly.rs @@ -0,0 +1,241 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
+ +#[cfg(test)] +use super::{SuperTrait}; + +use super::{ArithUtils}; + +// (X, Y) -> (X+Y, W(X-Y)) mod q +#[allow(non_snake_case)] +pub fn inverse_butterfly(X: &mut T, Y: &mut T, W: &T, q: &T) where T: ArithUtils{ + let temp = T::sub_mod(X,Y, q); + *X = T::add_mod(X, Y, q); + *Y = T::mul_mod(W, &temp, q); +} + +// (X, Y) -> (X+WY, X-WY) mod q +#[allow(non_snake_case)] +pub fn butterfly(X: &mut T, Y: &mut T, W: &T, q: &T) where T: ArithUtils{ + let temp = T::mul_mod(Y, W, q); + *Y = T::sub_mod(X, &temp, q); + *X = T::add_mod(X, &temp, q); +} + +// (X, Y) -> (X+WY, X-WY) +// 0 <= X, Y < 4q => (0 <= X', Y' < 4q) +#[allow(non_snake_case)] +#[cfg(test)] +pub fn lazy_butterfly(X: &mut T, Y: &mut T, W: u64, Wprime: u64, q: &T, twoq: u64) where T: SuperTrait{ + let mut xx = X.rep(); + if xx > twoq{ + xx -= twoq; + } + let _qq = super::util::mul_high_word(Wprime, Y.rep()); + let quo = W.wrapping_mul(Y.rep()) - _qq.wrapping_mul(q.rep()); + // X += quo; + *X = T::from(xx + quo); + // Y = (x + 2q - quo); + *Y = T::from(xx + twoq - quo); +} + +#[allow(clippy::many_single_char_names)] +pub fn lazy_butterfly_u64(mut x: u64, y:u64, w: u64, wprime: u64, q: u64, twoq: u64) -> (u64, u64){ + // let twoq = 0; + if x > twoq{ + x -= twoq; + } + let _qq = super::util::mul_high_word(wprime, y); + let wy = w.wrapping_mul(y); + let qqq = _qq.wrapping_mul(q); + let quo; + if wy >= qqq { + quo = wy - qqq; + } + else{ + quo = u64::MAX - qqq + wy + 1; + } + (x + quo, x + twoq - quo) +} + +#[allow(clippy::many_single_char_names)] +pub fn lazy_inverse_butterfly_u64(x: u64, y:u64, w: u64, wprime: u64, q: u64, twoq: u64) -> (u64, u64){ + let mut xx = x+y; + + if xx > twoq { + xx -= twoq; + } + let t = twoq - y + x; + let quo = super::util::mul_high_word(wprime, t); + let wt = w.wrapping_mul(t); + let qquo = quo.wrapping_mul(q); + let yy; + if wt >= qquo { + yy = wt - qquo; + } + else{ + yy = u64::MAX - qquo + wt + 1; + } + (xx, yy) +} + +// (X,Y) -> (X+Y, W(X-Y)) mod q +// 0 <= X, Y < 2q ==> 0 <= 
X', Y' < 2q +#[allow(non_snake_case)] +#[cfg(test)] +pub(crate) fn lazy_inverse_butterfly(X: &mut T, Y: &mut T, W: u64, Wprime: u64, q: &T) where T: SuperTrait{ + let mut xx = X.rep() + Y.rep(); + + let twoq = 2*q.rep(); + if xx > twoq { + xx -= twoq; + } + let t = twoq - Y.rep() + X.rep(); + let quo = super::util::mul_high_word(Wprime, t); + let yy = W.wrapping_mul(t) - quo.wrapping_mul(q.rep()); + *X = T::from(xx); + *Y = T::from(yy); +} + +#[cfg(test)] +mod tests { + use super::*; + use super::super::scalar::Scalar; + + fn butterfly_for_test(arr: [u64;4]) -> [u64; 2] { + let mut X:Scalar = Scalar::from(arr[0]); + let mut Y:Scalar = Scalar::from(arr[1]); + let W:Scalar = Scalar::from(arr[2]); + let q:Scalar = Scalar::new_modulus(arr[3]); + + butterfly(&mut X, &mut Y, &W, &q); + [X.into(), Y.into()] + } + + fn inverse_butterfly_for_test(arr: [u64;4]) -> [u64; 2] { + let mut X:Scalar = Scalar::from(arr[0]); + let mut Y:Scalar = Scalar::from(arr[1]); + let W:Scalar = Scalar::from(arr[2]); + let q:Scalar = Scalar::new_modulus(arr[3]); + + inverse_butterfly(&mut X, &mut Y, &W, &q); + [X.into(), Y.into()] + } + + fn lazy_butterfly_for_test(arr: [u64;4]) -> [u64; 2] { + let mut X:Scalar = Scalar::from(arr[0]); + let mut Y:Scalar = Scalar::from(arr[1]); + let W = arr[2]; + let q:Scalar = Scalar::new_modulus(arr[3]); + // W′ = ⌊W β/p⌋, 0 < W′ < β + let Wprime: u64 = super::super::util::compute_harvey_ratio(W, q.rep()); + let twoq = q.rep() << 1; + + lazy_butterfly(&mut X, &mut Y, W, Wprime, &q, twoq); + [X.into(), Y.into()] + } + + fn lazy_inverse_butterfly_for_test(arr: [u64;4]) -> [u64; 2] { + let mut X:Scalar = Scalar::from(arr[0]); + let mut Y:Scalar = Scalar::from(arr[1]); + let W = arr[2]; + let q:Scalar = Scalar::new_modulus(arr[3]); + // W′ = ⌊W β/p⌋, 0 < W′ < β + let Wprime: u64 = super::super::util::compute_harvey_ratio(W, q.rep()); + + lazy_inverse_butterfly(&mut X, &mut Y, W, Wprime, &q); + [X.into(), Y.into()] + } + + macro_rules! 
lazy_butterfly_tests { + ($($name:ident: $value:expr,)*) => { + $( + #[test] + fn $name() { + let input = $value; + let butterfly_out = butterfly_for_test(input); + let output = lazy_butterfly_for_test(input); + println!("{:?}", butterfly_out); + println!("{:?}", output); + assert!(output[0] < 4*input[3]); + assert!(output[1] < 4*input[3]); + assert_eq!((output[1] - butterfly_out[1]) % input[3], 0); + assert_eq!((output[0] - butterfly_out[0]) % input[3], 0); + } + )* + } + } + + macro_rules! lazy_inverse_butterfly_tests { + ($($name:ident: $value:expr,)*) => { + $( + #[test] + fn $name() { + let input = $value; + let butterfly_out = inverse_butterfly_for_test(input); + let output = lazy_inverse_butterfly_for_test(input); + println!("{:?}", butterfly_out); + println!("{:?}", output); + assert!(output[0] < 2*input[3]); + assert!(output[1] < 2*input[3]); + assert_eq!((output[1] - butterfly_out[1]) % input[3], 0); + assert_eq!((output[0] - butterfly_out[0]) % input[3], 0); + } + )* + } + } + + macro_rules! inverse_butterfly_tests { + ($($name:ident: $value:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $value; + assert_eq!(expected, inverse_butterfly_for_test(input)); + } + )* + } + } + + macro_rules! butterfly_tests { + ($($name:ident: $value:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $value; + assert_eq!(expected, butterfly_for_test(input)); + } + )* + } + } + + butterfly_tests! { + butterfly_0: ([0u64, 1u64, 0u64, 100u64], [0u64, 0u64]), + butterfly_1: ([1u64, 1u64, 1u64, 100u64], [2u64, 0u64]), + butterfly_2: ([50u64, 50u64, 1u64, 100u64], [0u64, 0u64]), + butterfly_3: ([1u64, 1u64, 50u64, 100u64], [51u64, 51u64]), + } + + inverse_butterfly_tests! 
{ + inverse_butterfly_0: ([0u64, 1u64, 0u64, 100u64], [1u64, 0u64]), + inverse_butterfly_1: ([1u64, 1u64, 1u64, 100u64], [2u64, 0u64]), + inverse_butterfly_2: ([50u64, 50u64, 1u64, 100u64], [0u64, 0u64]), + inverse_butterfly_3: ([2u64, 1u64, 50u64, 100u64], [3u64, 50u64]), + } + + lazy_butterfly_tests! { + lazy_butterfly_0: ([0u64, 1u64, 0u64, 100u64]), + lazy_butterfly_1: ([1u64, 1u64, 1u64, 100u64]), + lazy_butterfly_2: ([50u64, 50u64, 1u64, 100u64]), + lazy_butterfly_3: ([1u64, 1u64, 50u64, 100u64]), + } + + lazy_inverse_butterfly_tests! { + lazy_inverse_butterfly_0: ([0u64, 1u64, 0u64, 100u64]), + lazy_inverse_butterfly_1: ([1u64, 1u64, 1u64, 100u64]), + lazy_inverse_butterfly_2: ([50u64, 50u64, 1u64, 100u64]), + lazy_inverse_butterfly_3: ([1u64, 1u64, 50u64, 100u64]), + } +} diff --git a/src/integer_arith/mod.rs b/src/integer_arith/mod.rs index 11e1126..d2b2387 100644 --- a/src/integer_arith/mod.rs +++ b/src/integer_arith/mod.rs @@ -3,11 +3,12 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. pub mod scalar; +pub mod butterfly; +pub mod util; #[cfg(feature = "bigint")] pub mod bigint; -use rand::StdRng; /// The trait for utility functions related to scalar-like types. 
pub trait ArithUtils { @@ -21,7 +22,7 @@ pub trait ArithUtils { // sample a value in [0, bound-1] fn sample_blw(bound: &T) -> T; - fn sample_below_from_rng(bound: &T, rng: &mut StdRng) -> T; + fn sample_below_from_rng(bound: &T, rng: &mut dyn Rng) -> T; fn one() -> T { Self::from_u32_raw(1u32) @@ -53,3 +54,15 @@ pub trait ArithUtils { fn from_u64_raw(a: u64) -> T; fn to_u64(a: &T) -> u64; } + +pub trait ArithOperators{ + fn add_u64(&mut self, a: u64); + + fn sub_u64(&mut self, a: u64); + + fn rep(&self) -> u64; +} + +pub trait SuperTrait: ArithOperators + ArithUtils + Clone + From + From + PartialEq{} + +pub trait Rng: rand::CryptoRng + rand::RngCore {} \ No newline at end of file diff --git a/src/integer_arith/scalar.rs b/src/integer_arith/scalar.rs index 604d58d..f758629 100644 --- a/src/integer_arith/scalar.rs +++ b/src/integer_arith/scalar.rs @@ -2,13 +2,17 @@ // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use crate::integer_arith::ArithUtils; +use crate::integer_arith::{ArithOperators, ArithUtils, SuperTrait}; use modinverse::modinverse; -use rand::rngs::StdRng; -use rand::FromEntropy; -use rand::RngCore; +use rand::rngs::{StdRng,ThreadRng}; +use rand::{FromEntropy}; +use super::Rng; +use ::std::ops; pub use std::sync::Arc; +impl Rng for StdRng {} +impl Rng for ThreadRng {} + /// The ScalarContext class contains useful auxilliary information for fast modular reduction against a Scalar instance. 
#[derive(Debug, PartialEq, Eq, Clone)] struct ScalarContext { @@ -62,12 +66,16 @@ impl Scalar { } } +/// Trait implementations +impl SuperTrait for Scalar {} + impl PartialEq for Scalar { fn eq(&self, other: &Self) -> bool { self.rep == other.rep } } +// Conversions impl From for Scalar { fn from(item: u32) -> Self { Scalar { context: None, rep: item as u64, bit_count: 0 } @@ -80,6 +88,63 @@ impl From for Scalar { } } +impl From for u64{ + fn from(item: Scalar) -> u64 { + item.rep + } +} + +// Operators +impl ops::Add<&Scalar> for Scalar { + type Output = Scalar; + fn add(self, v: &Scalar) -> Scalar { + Scalar::new(self.rep + v.rep) + } +} + +impl ops::Add for Scalar { + type Output = Scalar; + fn add(self, v: Scalar) -> Scalar { + self + &v + } +} + +impl ops::Sub<&Scalar> for Scalar { + type Output = Scalar; + fn sub(self, v: &Scalar) -> Scalar { + Scalar::new(self.rep - v.rep) + } +} + +impl ops::Sub for Scalar { + type Output = Scalar; + fn sub(self, v: Scalar) -> Scalar { + self - &v + } +} + +impl ops::Mul for Scalar { + type Output = Scalar; + fn mul(self, v: u64) -> Scalar { + Scalar::new(self.rep * v) + } +} + +impl ArithOperators for Scalar{ + fn add_u64(&mut self, a: u64){ + self.rep += a; + } + + fn sub_u64(&mut self, a: u64){ + self.rep -= a; + } + + fn rep(&self) -> u64{ + self.rep + } +} + +// Trait implementation impl ArithUtils for Scalar { fn new_modulus(q: u64) -> Scalar { Scalar { @@ -158,17 +223,15 @@ impl ArithUtils for Scalar { } // sample below using a given rng. 
- fn sample_below_from_rng(upper_bound: &Scalar, rng: &mut StdRng) -> Self { - loop { - let n = Self::_sample_form_rng(upper_bound.bit_count, rng); - if n < upper_bound.rep { - return Scalar::new(n); - } - } + fn sample_below_from_rng(upper_bound: &Scalar, rng: &mut dyn Rng) -> Self { + upper_bound.sample(rng) } fn modulus(a: &Scalar, q: &Scalar) -> Scalar { - Scalar::new(a.rep % q.rep) + match &q.context{ + Some(context) => {Scalar::from(Scalar::_barret_reduce((a.rep(), 0), context.barrett_ratio, q.rep()))} + None => Scalar::new(a.rep % q.rep) + } } fn mul(a: &Scalar, b: &Scalar) -> Scalar { @@ -202,7 +265,17 @@ impl Scalar { res } - fn _sample_form_rng(bit_size: usize, rng: &mut StdRng) -> u64 { + fn sample(&self, rng: &mut dyn Rng) -> Scalar { + let max_multiple = self.rep() * (u64::MAX / self.rep() ); + loop{ + let a = rng.next_u64(); + if a < max_multiple { + return Scalar::modulus(&Scalar::from(a), self); + } + } + } + + fn _sample_from_rng(bit_size: usize, rng: &mut dyn Rng) -> u64 { let bytes = (bit_size - 1) / 8 + 1; let mut buf: Vec = vec![0; bytes]; rng.fill_bytes(&mut buf); @@ -219,7 +292,7 @@ impl Scalar { fn _sample(bit_size: usize) -> u64 { let mut rng = StdRng::from_entropy(); - Self::_sample_form_rng(bit_size, &mut rng) + Self::_sample_from_rng(bit_size, &mut rng) } fn _sub_mod(a: &Scalar, b: &Scalar, q: u64) -> Self { @@ -251,9 +324,10 @@ impl Scalar { // compute w = a*ratio >> 128. 
// start with lw(a1r1) - // let mut w= Scalar::multiply_u64(a.1, ratio.1).0; - let mut w = a.1.wrapping_mul(ratio.1); - + let mut w = 0; + if a.1 != 0{ + w = a.1.wrapping_mul(ratio.1); + } let a0r0 = Scalar::_multiply_u64(a.0, ratio.0); let a0r1 = Scalar::_multiply_u64(a.0, ratio.1); @@ -266,11 +340,13 @@ impl Scalar { w += carry as u64; // Round2 - let a1r0 = Scalar::_multiply_u64(a.1, ratio.0); - w += a1r0.1; - // final carry - let (_, carry2) = Scalar::_add_u64(a1r0.0, tmp); - w += carry2 as u64; + if a.1 != 0{ + let a1r0 = Scalar::_multiply_u64(a.1, ratio.0); + w += a1r0.1; + // final carry + let (_, carry2) = Scalar::_add_u64(a1r0.0, tmp); + w += carry2 as u64; + } // low = w*q mod 2^64. // let low = Scalar::multiply_u64(w, q).0; @@ -339,6 +415,17 @@ mod tests { assert!(Scalar::sample_blw(&q_scalar).rep < q); } } + + #[test] + fn test_sample_below_prng() { + use rand::{thread_rng}; + let q: u64 = 18014398492704769; + let q_scalar = Scalar::new_modulus(q); + let mut rng = thread_rng(); + for _ in 0..10 { + assert!(Scalar::sample_below_from_rng(&q_scalar, &mut rng).rep < q); + } + } #[test] fn test_equality() { assert_eq!(Scalar::zero(), Scalar::zero()); @@ -427,4 +514,12 @@ mod tests { assert_eq!(c, 6); } + + #[test] + fn test_operator_add(){ + let a = Scalar::new(123); + let b = Scalar::new(123); + let c = a + &b; + assert_eq!(u64::from(c), 246u64); + } } diff --git a/src/integer_arith/util.rs b/src/integer_arith/util.rs new file mode 100644 index 0000000..cd3c6aa --- /dev/null +++ b/src/integer_arith/util.rs @@ -0,0 +1,38 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
+ +/// computes floor(a*b/pow(2,64)) +pub fn mul_high_word(a: u64, b:u64) -> u64{ + ((a as u128 * b as u128) >> 64) as u64 +} + +/// computes floor(w*pow(2,64)/q) +pub fn compute_harvey_ratio(w: u64, q: u64) -> u64{ + (((w as u128) << 64 )/ q as u128) as u64 +} + +pub fn mul_low_word(a: u64, b: u64) -> u64 { + let res = (a as u128) * (b as u128); + (res >> 64) as u64 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mul_high_word(){ + assert_eq!(mul_high_word(1,1), 0); + assert_eq!(mul_high_word(1u64 << 63,0), 0); + assert_eq!(mul_high_word(1u64 << 63,2), 1); + assert_eq!(mul_high_word(1u64 << 63,1u64 << 63), 1u64 << 62); + } + + #[test] + fn test_compute_harvey_ratio(){ + assert_eq!(compute_harvey_ratio(0,100), 0); + assert_eq!(compute_harvey_ratio(1,100), 184467440737095516); + } +} diff --git a/src/lib.rs b/src/lib.rs index f18214e..a764384 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -170,13 +170,21 @@ pub mod integer_arith; +pub mod polyarith; +#[cfg(feature = "bench")] +pub mod rqpoly; +#[cfg(not(feature = "bench"))] mod rqpoly; pub mod traits; mod serialize; mod utils; +#[cfg(feature = "bench")] +pub mod randutils; +#[cfg(not(feature = "bench"))] +mod randutils; use integer_arith::scalar::Scalar; -use integer_arith::ArithUtils; +use integer_arith::{SuperTrait, ArithUtils}; use traits::*; use std::sync::Arc; @@ -192,7 +200,7 @@ pub type DefaultShemeType = FV; /// SecretKey type pub struct SecretKey(RqPoly); -use rqpoly::{FiniteRingElt, RqPoly, RqPolyContext, NTT}; +use rqpoly::{FiniteRingElt, RqPoly, RqPolyContext}; pub fn default() -> DefaultShemeType { FV::::default_2048() @@ -252,7 +260,7 @@ where // add a plaintext into a FVCiphertext. 
fn add_plain_inplace(&self, ct: &mut FVCiphertext, pt: &FVPlaintext) { for (ct_coeff, pt_coeff) in ct.1.coeffs.iter_mut().zip(pt.iter()) { - let temp = T::mul(&pt_coeff, &self.delta); + let temp = T::mul(pt_coeff, &self.delta); *ct_coeff = T::add_mod(ct_coeff, &temp, &self.q); } } @@ -277,7 +285,7 @@ where impl AdditiveHomomorphicScheme, SecretKey> for FV where RqPoly: FiniteRingElt, - T: Clone + ArithUtils + PartialEq + From, + T: SuperTrait, { fn add_inplace(&self, ct1: &mut FVCiphertext, ct2: &FVCiphertext) { ct1.0.add_inplace(&ct2.0); @@ -292,7 +300,7 @@ where // add large noise poly for noise flooding. let elarge = - rqpoly::randutils::sample_gaussian_poly(self.context.clone(), self.flooding_stdev); + randutils::sample_gaussian_poly(self.context.clone(), self.flooding_stdev); ct.1.add_inplace(&elarge); } } @@ -300,7 +308,7 @@ where // constructor and random poly sampling impl FV where - T: ArithUtils + Clone + PartialEq + Serializable + From, + T: SuperTrait+ PartialEq + Serializable, RqPoly: FiniteRingElt + NTT, { pub fn new(n: usize, q: &T) -> Self { @@ -322,8 +330,8 @@ where n, t: t.clone(), flooding_stdev: 2f64.powi(40), - delta: T::div(q, &t), // &q/t, - qdivtwo: T::div(q, &T::from(2)), // &q/2, + delta: T::div(q, t), // &q/t, + qdivtwo: T::div(q, &T::from(2_u32)), // &q/2, q: q.clone(), stdev: 3.2, context, @@ -387,10 +395,10 @@ impl FV { impl KeyGeneration, SecretKey> for FV where RqPoly: FiniteRingElt, - T: Clone + ArithUtils + PartialEq + From, + T: SuperTrait, { fn generate_key(&self) -> SecretKey { - let mut skpoly = rqpoly::randutils::sample_ternary_poly(self.context.clone()); + let mut skpoly = randutils::sample_ternary_poly(self.context.clone()); if self.context.is_ntt_enabled { skpoly.forward_transform(); } @@ -411,21 +419,24 @@ where impl EncryptionOfZeros, SecretKey> for FV where RqPoly: FiniteRingElt, - T: Clone + ArithUtils + PartialEq + From, + T: SuperTrait, { fn encrypt_zero(&self, pk: &FVCiphertext) -> FVCiphertext { - let mut u = 
rqpoly::randutils::sample_ternary_poly_prng(self.context.clone()); - let e1 = rqpoly::randutils::sample_gaussian_poly(self.context.clone(), self.stdev); - let e2 = rqpoly::randutils::sample_gaussian_poly(self.context.clone(), self.stdev); + let mut u = randutils::sample_ternary_poly_prng(self.context.clone()); + let e1 = randutils::sample_gaussian_poly(self.context.clone(), self.stdev); + let e2 = randutils::sample_gaussian_poly(self.context.clone(), self.stdev); if self.context.is_ntt_enabled { u.forward_transform(); } // c0 = au + e1 + // let mut c0 = RqPoly::new(self.context.clone()); let mut c0 = (self.poly_multiplier)(&pk.0, &u); c0.add_inplace(&e1); // c1 = bu + e2 + // let mut c1 = RqPoly::new(self.context.clone()); + let mut c1 = (self.poly_multiplier)(&pk.1, &u); c1.add_inplace(&e2); @@ -433,8 +444,8 @@ where } fn encrypt_zero_sk(&self, sk: &SecretKey) -> FVCiphertext { - let e = rqpoly::randutils::sample_gaussian_poly(self.context.clone(), self.stdev); - let a = rqpoly::randutils::sample_uniform_poly(self.context.clone()); + let e = randutils::sample_gaussian_poly(self.context.clone(), self.stdev); + let a = randutils::sample_uniform_poly(self.context.clone()); let mut b = (self.poly_multiplier)(&a, &sk.0); b.add_inplace(&e); (a, b) @@ -444,7 +455,7 @@ where impl PKEncryption, FVPlaintext, SecretKey> for FV where RqPoly: FiniteRingElt, - T: Clone + ArithUtils + PartialEq + From, + T: SuperTrait, { fn encrypt(&self, pt: &FVPlaintext, pk: &FVCiphertext) -> FVCiphertext { // use public key to encrypt @@ -453,7 +464,7 @@ where // c1 = bu+e2 + Delta*m let iter = c1.coeffs.iter_mut().zip(pt.iter()); for (x, y) in iter { - let temp = T::mul(&y, &self.delta); + let temp = T::mul(y, &self.delta); *x = T::add_mod(x, &temp, &self.q); } (c0, c1) @@ -463,7 +474,7 @@ where impl PKEncryption, DefaultFVPlaintext, SecretKey> for FV where RqPoly: FiniteRingElt, - T: Clone + ArithUtils + PartialEq + From, + T: SuperTrait, { fn encrypt(&self, pt: &DefaultFVPlaintext, pk: 
&FVCiphertext) -> FVCiphertext { let pt1 = self.convert_pt_u8_to_scalar(pt); @@ -474,7 +485,7 @@ where impl SKEncryption, DefaultFVPlaintext, SecretKey> for FV where RqPoly: FiniteRingElt, - T: Clone + ArithUtils + PartialEq + From, + T: SuperTrait, { fn encrypt_sk(&self, pt: &DefaultFVPlaintext, sk: &SecretKey) -> FVCiphertext { @@ -492,11 +503,12 @@ where impl SKEncryption, FVPlaintext, SecretKey> for FV where RqPoly: FiniteRingElt, - T: Clone + ArithUtils + PartialEq + From, + T: SuperTrait, { fn encrypt_sk(&self, pt: &FVPlaintext, sk: &SecretKey) -> FVCiphertext { - let e = rqpoly::randutils::sample_gaussian_poly(self.context.clone(), self.stdev); - let a = rqpoly::randutils::sample_uniform_poly(self.context.clone()); + let e = randutils::sample_gaussian_poly(self.context.clone(), self.stdev); + let a = randutils::sample_uniform_poly(self.context.clone()); + let mut b = (self.poly_multiplier)(&a, &sk.0); b.add_inplace(&e); @@ -504,7 +516,7 @@ where // add scaled plaintext to let iter = b.coeffs.iter_mut().zip(pt.iter()); for (x, y) in iter { - let temp = T::mul(&y, &self.delta); + let temp = T::mul(y, &self.delta); *x = T::add_mod(x, &temp, &self.q); } (a, b) @@ -515,20 +527,22 @@ where let mut phase = ct.1.clone(); phase.sub_inplace(&temp1); // then, extract value from phase. - let mut c: Vec = vec![]; - for x in phase.coeffs { - // let mut tmp = x << 8; // x * t, need to make sure there's no overflow. - let mut tmp = T::mul(&x, &self.t); - // tmp += &self.qdivtwo; - tmp = T::add(&tmp, &self.qdivtwo); - // tmp /= &self.q; - tmp = T::div(&tmp, &self.q); - // modulo t. 
- tmp = T::modulus(&tmp, &self.t); - - c.push(tmp); - } - c + let tt: u128 = u128::from(self.t.rep()); + let qq: u128 = u128::from(self.q.rep()); + let qdivtwo = qq / 2; + + let my_closure = |elm: &T| -> T{ + let mut tmp: u128 = u128::from(elm.rep()); // widened: elm * t can exceed 64 bits for large q + tmp *= tt; + tmp += qdivtwo; + tmp /= qq; + tmp %= tt; + T::from(tmp as u64) // tmp < t after the reduction, so the cast is lossless + }; + + phase.coeffs.iter() + .map(my_closure) + .collect() } } diff --git a/src/polyarith/lazy_ntt.rs b/src/polyarith/lazy_ntt.rs new file mode 100644 index 0000000..b5ec561 --- /dev/null +++ b/src/polyarith/lazy_ntt.rs @@ -0,0 +1,46 @@ +use crate::integer_arith::butterfly::{lazy_butterfly_u64, lazy_inverse_butterfly_u64}; + +pub fn lazy_ntt_u64(vec: &mut [u64], roots: &[u64], scaled_roots: &[u64], q: u64){ + let n = vec.len(); + let twoq = q << 1; + + let mut t = n; + let mut m = 1; + while m < n { + t >>= 1; + for i in 0..m { + let j1 = 2 * i * t; + let j2 = j1 + t - 1; + let phi = roots[m + i]; + for j in j1..j2 + 1 { + let new = lazy_butterfly_u64(vec[j], vec[j+t], phi, scaled_roots[m+i], q, twoq); + vec[j] = new.0; + vec[j+t] = new.1; + } + } + m <<= 1; + } +} + +pub fn lazy_inverse_ntt_u64(vec: &mut [u64], invroots: &[u64], scaled_invroots: &[u64], q: u64){ + let twoq = q << 1; + + let mut t = 1; + let mut m = vec.len(); + while m > 1 { + let mut j1 = 0; + let h = m >> 1; + for i in 0..h { + let j2 = j1 + t - 1; + for j in j1..j2 + 1 { + // inverse butterfly + let new = lazy_inverse_butterfly_u64(vec[j], vec[j+t], invroots[h+i], scaled_invroots[h+i], q, twoq); + vec[j] = new.0; + vec[j+t] = new.1; + } + j1 += 2 * t; + } + t <<= 1; + m >>= 1; + } +} \ No newline at end of file diff --git a/src/polyarith/mod.rs b/src/polyarith/mod.rs new file mode 100644 index 0000000..6385be2 --- /dev/null +++ b/src/polyarith/mod.rs @@ -0,0 +1 @@ +pub mod lazy_ntt; diff --git a/src/randutils.rs b/src/randutils.rs new file mode 100644 index 0000000..f678e79 --- /dev/null +++ b/src/randutils.rs @@ -0,0 +1,97 @@ +// Copyright (c) Facebook, Inc. and its affiliates.
+// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Utility functions for generating random polynomials. +use rand::distributions::{Distribution, Normal}; +use rand::rngs::{OsRng, StdRng}; +use rand::FromEntropy; +use rand::{thread_rng, Rng}; + +use super::*; + +use crate::rqpoly::RqPolyContext; + +pub fn sample_ternary_poly(context: Arc>) -> RqPoly +where + T: SuperTrait, +{ + let mut rng = OsRng::new().unwrap(); + let q = context.q.rep() as i64; + let c = (0..context.n).map(|_| { + let mut t = rng.gen_range(-1i32, 2i32) as i64; + if t < 0{ + t += q; + } + T::from(t as u64) + }).collect::>(); + RqPoly { + coeffs: c, + is_ntt_form: false, + context: Some(context), + } +} + +pub fn sample_ternary_poly_prng(context: Arc>) -> RqPoly +where + T: SuperTrait, +{ + let mut rng = StdRng::from_entropy(); + let q = context.q.rep(); + let q_minus_one = q-1 as u64; + + let c = (0..context.n).map(|_| { + let t = rng.gen_range(-1i32, 2i32) as i64; + let mut s: u64 = t as u64; + if t < 0 { + s = q_minus_one; + } + T::from(s) + }).collect::>(); + + RqPoly { + coeffs: c, + is_ntt_form: false, + context: Some(context), + } +} + +/// Sample a polynomial with Gaussian coefficients in the ring Rq. +pub fn sample_gaussian_poly(context: Arc>, stdev: f64) -> RqPoly +where + T: SuperTrait, +{ + let normal = Normal::new(0.0, stdev); + let mut rng = thread_rng(); + let q: u64 = context.q.rep(); // kept as an integer: q can exceed 2^53, where f64 is no longer exact + + let c = (0..context.n).map(|_| { + let tmp = normal.sample(&mut rng); + if tmp < 0.0 { + return T::from(q.saturating_sub((-tmp).ceil() as u64)); // floor(tmp) mod q, computed exactly + } + T::from(tmp as u64) + }).collect::>(); + + RqPoly { + coeffs: c, + is_ntt_form: false, + context: Some(context), + } +} + +/// Sample a uniform polynomial in the ring Rq.
+pub fn sample_uniform_poly(context: Arc>) -> RqPoly +where + T: SuperTrait, +{ + let mut rng = thread_rng(); + + let c: Vec = (0..context.n).map(|_| T::sample_below_from_rng(&context.q, &mut rng)).collect(); + RqPoly { + coeffs: c, + is_ntt_form: false, + context: Some(context), + } +} diff --git a/src/rqpoly.rs b/src/rqpoly.rs index 0c46a37..8492695 100644 --- a/src/rqpoly.rs +++ b/src/rqpoly.rs @@ -2,37 +2,53 @@ // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use crate::integer_arith::ArithUtils; +use crate::integer_arith::{SuperTrait, ArithUtils}; +use crate::polyarith::lazy_ntt::{lazy_ntt_u64, lazy_inverse_ntt_u64}; +use crate::integer_arith::util::compute_harvey_ratio; use crate::utils::reverse_bits_perm; use std::sync::Arc; +use crate::traits::*; /// Holds the context information for RqPolys, including degree n, modulus q, and optionally precomputed /// roots of unity for NTT purposes. #[derive(Debug)] -pub(crate) struct RqPolyContext { +pub struct RqPolyContext { pub n: usize, pub q: T, pub is_ntt_enabled: bool, pub roots: Vec, pub invroots: Vec, + pub scaled_roots: Vec, // for use in lazy ntt + pub scaled_invroots: Vec, // for use in lazy inverse ntt } /// Polynomials in Rq = Zq[x]/(x^n + 1).
#[derive(Clone, Debug)] pub struct RqPoly { - context: Option>>, + pub(crate) context: Option>>, pub coeffs: Vec, pub is_ntt_form: bool, } impl RqPoly where T:Clone{ - pub fn new(coeffs: &Vec, is_ntt_form:bool) -> Self{ + pub fn new_without_context(coeffs: &[T], is_ntt_form:bool) -> Self{ RqPoly{ context: None, coeffs: coeffs.to_vec(), is_ntt_form, } } +} + +impl RqPoly where T:Clone + ArithUtils{ + pub fn new(context: Arc>) -> Self{ + let n = context.n; + RqPoly{ + context: Some(context), + coeffs: vec![T::zero(); n], + is_ntt_form: false + } + } pub(crate) fn set_context(&mut self, context: Arc>){ self.context = Some(context); @@ -52,21 +68,6 @@ impl PartialEq for RqPoly where T: PartialEq { } } -/// Number-theoretic transform (NTT) and fast polynomial multiplication based on NTT. -pub trait NTT: Clone { - fn is_ntt_form(&self) -> bool; - - fn set_ntt_form(&mut self, value: bool); - - fn forward_transform(&mut self); - - fn inverse_transform(&mut self); - - fn coeffwise_multiply(&self, other: &Self) -> Self; - - fn multiply_fast(&self, other: &Self) -> Self; -} - /// Arithmetics on general ring elements. 
pub trait FiniteRingElt { fn add_inplace(&mut self, other: &Self); @@ -80,7 +81,7 @@ pub trait FiniteRingElt { impl RqPolyContext where - T: ArithUtils + PartialEq + Clone + From, + T: SuperTrait, { pub fn new(n: usize, q: &T) -> Self { let mut a = RqPolyContext { @@ -89,11 +90,29 @@ where is_ntt_enabled: false, invroots: vec![], roots: vec![], + scaled_roots: vec![], + scaled_invroots: vec![], }; a.compute_roots(); + a.compute_scaled_roots(); + a } + fn compute_scaled_roots(&mut self){ + if !self.is_ntt_enabled{ + return; + } + // compute scaled roots as wiprime = wi + for i in 0..self.n { + self.scaled_roots.push(T::from(compute_harvey_ratio(self.roots[i].rep(), self.q.rep()))); + } + + for i in 0..self.n { + self.scaled_invroots.push(T::from(compute_harvey_ratio(self.invroots[i].rep(), self.q.rep()))); + } + } + fn compute_roots(&mut self) { let mut roots = vec![]; @@ -102,24 +121,22 @@ where self.is_ntt_enabled = false; return; } - self.is_ntt_enabled = true; let phi = root.unwrap(); - + let mut s = T::one(); for _ in 0..self.n { roots.push(s.clone()); s = T::mul_mod(&s, &phi, &self.q); } - // now bit reverse a vector - reverse_bits_perm(&mut roots); self.roots = roots; - + let mut invroots: Vec = vec![]; for x in self.roots.iter() { invroots.push(T::inv_mod(x, &self.q)); } self.invroots = invroots; + self.is_ntt_enabled = true; } pub fn find_root(&self) -> Option { @@ -147,97 +164,91 @@ where } } -// NTT implementation -impl NTT for RqPoly +impl RqPoly where - T: ArithUtils + Clone, + T: SuperTrait { - fn is_ntt_form(&self) -> bool { - self.is_ntt_form - } - - fn set_ntt_form(&mut self, value: bool) { - self.is_ntt_form = value; - } - - fn forward_transform(&mut self) { + fn lazy_ntt(&mut self) + { let context = self.context.as_ref().unwrap(); if self.is_ntt_form { panic!("is already in ntt"); } + let q = context.q.rep(); - let n = context.n; - let q = context.q.clone(); + let mut coeffs_u64: Vec = self.coeffs.iter() + .map(|elm| elm.rep()) + .collect(); - let 
mut t = n; - let mut m = 1; - while m < n { - t >>= 1; - for i in 0..m { - let j1 = 2 * i * t; - let j2 = j1 + t - 1; - let phi = &context.roots[m + i]; - for j in j1..j2 + 1 { - let x = T::mul_mod(&self.coeffs[j + t], &phi, &q); - self.coeffs[j + t] = T::sub_mod(&self.coeffs[j], &x, &q); - self.coeffs[j] = T::add_mod(&self.coeffs[j], &x, &q); - } - } - m <<= 1; + let roots_u64: Vec = context.roots.iter() + .map(|elm| elm.rep()) + .collect(); + let scaledroots_u64: Vec = context.scaled_roots.iter() + .map(|elm| elm.rep()) + .collect(); + + lazy_ntt_u64(&mut coeffs_u64, &roots_u64, &scaledroots_u64, q); + + for (coeff, coeff_u64) in self.coeffs.iter_mut().zip(coeffs_u64.iter()){ + *coeff = T::modulus(&T::from(*coeff_u64), &context.q); } self.set_ntt_form(true); } - fn inverse_transform(&mut self) { + fn lazy_inverse_ntt(&mut self){ let context = self.context.as_ref().unwrap(); if !self.is_ntt_form { panic!("is already not in ntt"); } let n = context.n; let q = context.q.clone(); - - let mut t = 1; - let mut m = n; let ninv = T::inv_mod(&T::from_u32(n as u32, &q), &q); - while m > 1 { - let mut j1 = 0; - let h = m >> 1; - for i in 0..h { - let j2 = j1 + t - 1; - let s = &context.invroots[h + i]; - for j in j1..j2 + 1 { - let u = self.coeffs[j].clone(); - let v = self.coeffs[j + t].clone(); - self.coeffs[j] = T::add_mod(&u, &v, &q); - - let tmp = T::sub_mod(&u, &v, &q); - self.coeffs[j + t] = T::mul_mod(&tmp, &s, &q); - } - j1 += 2 * t; - } - t <<= 1; - m >>= 1; - } - for x in 0..n { - self.coeffs[x] = T::mul_mod(&ninv, &self.coeffs[x], &q); + + let mut coeffs_u64: Vec = self.coeffs.iter() + .map(|elm| elm.rep()) + .collect(); + + let invroots_u64: Vec = context.invroots.iter() + .map(|elm| elm.rep()) + .collect(); + + let scaled_invroots_u64: Vec = context.scaled_invroots.iter() + .map(|elm| elm.rep()) + .collect(); + + lazy_inverse_ntt_u64(&mut coeffs_u64, &invroots_u64, &scaled_invroots_u64, q.rep()); + + for (coeff, coeff_u64) in 
self.coeffs.iter_mut().zip(coeffs_u64.iter()){ + *coeff = T::mul_mod(&ninv, &T::from(*coeff_u64), &context.q); } self.set_ntt_form(false); } +} - fn coeffwise_multiply(&self, other: &Self) -> Self { - let context = self.context.as_ref().unwrap(); - let mut c = self.clone(); - for (inputs, cc) in self - .coeffs - .iter() - .zip(other.coeffs.iter()) - .zip(c.coeffs.iter_mut()) - { - *cc = T::mul_mod(inputs.0, inputs.1, &context.q); - } - c +// NTT implementation(lazy version) +impl NTT for RqPoly +where + T: SuperTrait +{ + fn is_ntt_form(&self) -> bool { + self.is_ntt_form + } + + fn set_ntt_form(&mut self, value: bool) { + self.is_ntt_form = value; + } + + fn forward_transform(&mut self) { + self.lazy_ntt() + } + + fn inverse_transform(&mut self) { + self.lazy_inverse_ntt() } +} +impl FastPolyMultiply for RqPoly +where T: SuperTrait{ fn multiply_fast(&self, other: &Self) -> Self { let mut a: Self = self.clone(); let mut b = other.clone(); @@ -252,8 +263,23 @@ where c.inverse_transform(); c } + + fn coeffwise_multiply(&self, other: &Self) -> Self { + let context = self.context.as_ref().unwrap(); + let mut c = self.clone(); + for (inputs, cc) in self + .coeffs + .iter() + .zip(other.coeffs.iter()) + .zip(c.coeffs.iter_mut()) + { + *cc = T::mul_mod(inputs.0, inputs.1, &context.q); + } + c + } } + impl FiniteRingElt for RqPoly where T: ArithUtils + Clone, @@ -310,100 +336,6 @@ where } } -/// Utility functions for generating random polynomials. 
-pub(crate) mod randutils { - use rand::distributions::{Distribution, Normal}; - use rand::rngs::{OsRng, StdRng}; - use rand::FromEntropy; - use rand::{thread_rng, Rng}; - use super::*; - - pub(crate) fn sample_ternary_poly(context: Arc>) -> RqPoly - where - T: ArithUtils + From, - { - let mut rng = OsRng::new().unwrap(); - let mut c = vec![]; - for _x in 0..context.n { - let t = rng.gen_range(-1i32, 2i32); - if t >= 0 { - c.push(T::from(t as u32)); - } else { - c.push(T::sub(&context.q, &T::one())); - } - } - RqPoly { - coeffs: c, - is_ntt_form: false, - context: Some(context.clone()), - } - } - - pub(crate) fn sample_ternary_poly_prng(context: Arc>) -> RqPoly - where - T: ArithUtils + From, - { - let mut rng = StdRng::from_entropy(); - let mut c = vec![]; - for _x in 0..context.n { - let t = rng.gen_range(-1i32, 2i32); - if t >= 0 { - c.push(T::from(t as u32)); - } else { - c.push(T::sub(&context.q, &T::one())); - } - } - RqPoly { - coeffs: c, - is_ntt_form: false, - context: Some(context), - } - } - - /// Sample a polynomial with Gaussian coefficients in the ring Rq. - pub(crate) fn sample_gaussian_poly(context: Arc>, stdev: f64) -> RqPoly - where - T: ArithUtils, - { - let mut c = vec![]; - let normal = Normal::new(0.0, stdev); - let mut rng = thread_rng(); - for _ in 0..context.n { - let tmp = normal.sample(&mut rng); - - // branch on sign - if tmp >= 0.0 { - c.push(T::from_u64_raw(tmp as u64)); - } else { - let neg = T::from_u64_raw(-tmp as u64); - c.push(T::sub(&context.q, &neg)); - } - } - RqPoly { - coeffs: c, - is_ntt_form: false, - context: Some(context), - } - } - - /// Sample a uniform polynomial in the ring Rq. 
- pub(crate) fn sample_uniform_poly(context: Arc>) -> RqPoly - where - T: ArithUtils + From, - { - let mut c = vec![]; - let mut rng = StdRng::from_entropy(); - for _x in 0..context.n { - c.push(T::sample_below_from_rng(&context.q, &mut rng)); - } - RqPoly { - coeffs: c, - is_ntt_form: false, - context: Some(context.clone()), - } - } -} - #[cfg(test)] mod tests { use super::*; @@ -447,7 +379,7 @@ mod tests { let q = Scalar::new_modulus(18014398492704769u64); let context = RqPolyContext::new(2048, &q); let arc = Arc::new(context); - let a = randutils::sample_uniform_poly(arc.clone()); + let a = crate::randutils::sample_uniform_poly(arc.clone()); let mut aa = a.clone(); aa.forward_transform(); aa.inverse_transform(); @@ -480,8 +412,8 @@ mod tests { let context = RqPolyContext::new(2048, &q); let arc = Arc::new(context); - let a = randutils::sample_uniform_poly(arc.clone()); - let b = randutils::sample_uniform_poly(arc.clone()); + let a = crate::randutils::sample_uniform_poly(arc.clone()); + let b = crate::randutils::sample_uniform_poly(arc.clone()); let c = a.multiply(&b); let c1 = a.multiply_fast(&b); assert_eq!(c.coeffs, c1.coeffs); @@ -492,4 +424,34 @@ mod tests { let context2 = RqPolyContext::new(4, &Scalar::new_modulus(12289)); assert_eq!(context2.find_root().unwrap(), Scalar::from_u64_raw(8246u64)); } + + #[test] + fn test_lazy_ntt(){ + let q = Scalar::new_modulus(18014398492704769u64); + let context = RqPolyContext::new(4, &q); + let arc = Arc::new(context); + let mut a = crate::randutils::sample_uniform_poly(arc.clone()); + let mut aa = a.clone(); + + aa.forward_transform(); + a.lazy_ntt(); + + // assert + assert_eq!(aa.coeffs, a.coeffs); + } + + #[test] + fn test_lazy_inverse_ntt(){ + let q = Scalar::new_modulus(18014398492704769u64); + let context = RqPolyContext::new(4, &q); + let arc = Arc::new(context); + let mut a = crate::randutils::sample_uniform_poly(arc.clone()); + a.set_ntt_form(true); + let mut aa = a.clone(); + aa.inverse_transform(); + 
a.lazy_inverse_ntt(); + + // assert + assert_eq!(aa.coeffs, a.coeffs); + } } diff --git a/src/serialize.rs b/src/serialize.rs index 25de2f6..8334b19 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -34,11 +34,9 @@ where T: Serializable + Clone, { fn to_bytes(&self) -> std::vec::Vec { - let mut vec: Vec = Vec::new(); - // push in the is ntt form. - vec.push(self.is_ntt_form as u8); - for i in 0..self.coeffs.len() { - let mut bytes = self.coeffs[i].to_bytes(); + let mut vec: Vec = vec![self.is_ntt_form as u8]; + for coeff in &self.coeffs { + let mut bytes = coeff.to_bytes(); vec.append(&mut bytes); } vec @@ -51,7 +49,7 @@ where coeffs.push(T::from_bytes(&bytes[i..i+8].to_vec())); i += 8; } - RqPoly::new(&coeffs, is_ntt_form) + RqPoly::new_without_context(&coeffs, is_ntt_form) } } @@ -89,7 +87,7 @@ mod tests { for i in 0..4 { coeffs.push(Scalar::from_u64_raw(i)); } - let testpoly = RqPoly::::new(&coeffs, false); + let testpoly = RqPoly::::new_without_context(&coeffs, false); let bytes = testpoly.to_bytes(); let deserialized = RqPoly::::from_bytes(&bytes); assert_eq!(testpoly, deserialized); diff --git a/src/traits.rs b/src/traits.rs index 598aa68..5bd1fb3 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -61,3 +61,20 @@ pub trait Serializable{ /// Deserialize from a vector of bytes. fn from_bytes(bytes: &Vec) -> Self; } + +/// Number-theoretic transform (NTT) and fast polynomial multiplication based on NTT. +pub trait NTT: Clone { + fn is_ntt_form(&self) -> bool; + + fn set_ntt_form(&mut self, value: bool); + + fn forward_transform(&mut self); + + fn inverse_transform(&mut self); +} + +pub trait FastPolyMultiply: NTT { + fn multiply_fast(&self, other: &Self) -> Self; + + fn coeffwise_multiply(&self, other: &Self) -> Self; +}