Skip to content

Commit

Permalink
Flexible plaintext mod (#4)
Browse files Browse the repository at this point in the history
Added support for flexible plaintext modulus. Valid plaintext modulus can be any integer up to 1024.
  • Loading branch information
haochenuw authored May 18, 2021
1 parent 381b99e commit c8ee75a
Show file tree
Hide file tree
Showing 8 changed files with 322 additions and 148 deletions.
2 changes: 1 addition & 1 deletion examples/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ fn main() {

let mut ctv = fv.encrypt(&v, &pk);

let mut pt_actual = fv.decrypt(&ctv, &sk);
let mut pt_actual: Vec<u8> = fv.decrypt(&ctv, &sk);
print!("decrypted v: ");
smartprint(&pt_actual);

Expand Down
6 changes: 3 additions & 3 deletions examples/rerandomize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
use cupcake::traits::{AdditiveHomomorphicScheme, PKEncryption, SKEncryption};
use cupcake::traits::{AdditiveHomomorphicScheme, PKEncryption, SKEncryption, KeyGeneration};

fn smartprint<T: std::fmt::Debug>(v: &Vec<T>) {
println!("[{:?}, {:?}, ..., {:?}]", v[0], v[1], v[v.len() - 1]);
Expand All @@ -21,15 +21,15 @@ fn main() {

let mut ctv = fv.encrypt(&v, &pk);

let pt_original = fv.decrypt(&ctv, &sk);
let pt_original: Vec<u8> = fv.decrypt(&ctv, &sk);
print!("decrypted value: ");
smartprint(&pt_original);

println!("Rerandomizing the ciphertext...");

fv.rerandomize(&mut ctv, &pk);
print!("decrypted value after reranromization: ");
let pt_new = fv.decrypt(&ctv, &sk);
let pt_new: Vec<u8> = fv.decrypt(&ctv, &sk);
smartprint(&pt_new);

print!("Check that the plaintext has not changed...");
Expand Down
2 changes: 1 addition & 1 deletion examples/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn main() {
println!("Adding the deserialized ciphertexts...");
fv.add_inplace(&mut deserialized_ctv, &deserialized_ctw);
println!("Decrypting the sum...");
let pt_actual = fv.decrypt(&deserialized_ctv, &sk);
let pt_actual: Vec<u8> = fv.decrypt(&deserialized_ctv, &sk);
println!("decrypted v+w: ");
smartprint(&pt_actual);
}
6 changes: 5 additions & 1 deletion src/integer_arith/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ pub mod bigint;
use rand::StdRng;
/// The trait for utility functions related to scalar-like types.
pub trait ArithUtils<T> {

/// Construct a new "modulus", which is a u64 plus information needed for fast modular reduction.
fn new_modulus(a: u64) -> T;

fn modulus(a: &T, q: &T) -> T;

fn double(a: &T) -> T;
Expand Down Expand Up @@ -47,5 +51,5 @@ pub trait ArithUtils<T> {
// conversion
fn from_u32_raw(a: u32) -> T;
fn from_u64_raw(a: u64) -> T;
fn to_u64(a: T) -> u64;
fn to_u64(a: &T) -> u64;
}
41 changes: 26 additions & 15 deletions src/integer_arith/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,6 @@ impl Scalar {
}
}

/// Construct a new "modulus", which is a u64 plus information needed for fast modular reduction.
pub fn new_modulus(q: u64) -> Self {
Scalar {
rep: q,
context: Some(ScalarContext::new(q)),
bit_count: 64 - q.leading_zeros() as usize,
}
}

pub fn rep(&self) -> u64{
self.rep
}
Expand All @@ -77,7 +68,27 @@ impl PartialEq for Scalar {
}
}

impl From<u32> for Scalar {
fn from(item: u32) -> Self {
Scalar { context: None, rep: item as u64, bit_count: 0 }
}
}

impl From<u64> for Scalar {
fn from(item: u64) -> Self {
Scalar { context: None, rep: item, bit_count: 0 }
}
}

impl ArithUtils<Scalar> for Scalar {
fn new_modulus(q: u64) -> Scalar {
Scalar {
rep: q,
context: Some(ScalarContext::new(q)),
bit_count: 64 - q.leading_zeros() as usize,
}
}

fn sub(a: &Scalar, b: &Scalar) -> Scalar {
Scalar::new(a.rep - b.rep)
}
Expand Down Expand Up @@ -164,7 +175,7 @@ impl ArithUtils<Scalar> for Scalar {
Scalar::new(a.rep * b.rep)
}

fn to_u64(a: Scalar) -> u64 {
fn to_u64(a: &Scalar) -> u64 {
a.rep
}

Expand Down Expand Up @@ -294,16 +305,16 @@ mod tests {
use super::*;
#[test]
fn test_bitlength() {
assert_eq!(Scalar::from_u32_raw(2).bit_length(), 2);
assert_eq!(Scalar::from_u32_raw(16).bit_length(), 5);
assert_eq!(Scalar::from(2u32).bit_length(), 2);
assert_eq!(Scalar::from(16u32).bit_length(), 5);
assert_eq!(Scalar::from_u64_raw(18014398492704769u64).bit_length(), 54);
}

#[test]
fn test_getbits() {
assert_eq!(Scalar::from_u32_raw(1).get_bits(), vec![true]);
assert_eq!(Scalar::from_u32_raw(2).get_bits(), vec![false, true]);
assert_eq!(Scalar::from_u32_raw(5).get_bits(), vec![true, false, true]);
assert_eq!(Scalar::from(1u32).get_bits(), vec![true]);
assert_eq!(Scalar::from(2u32).get_bits(), vec![false, true]);
assert_eq!(Scalar::from(5u32).get_bits(), vec![true, false, true]);
assert_eq!(
Scalar::from_u64_raw(127).get_bits(),
vec![true, true, true, true, true, true, true]
Expand Down
Loading

0 comments on commit c8ee75a

Please sign in to comment.