# CryptoSystem Interface

The `CryptoSystem` interface provides a uniform API for cryptographic operations such as key generation, encryption, decryption, and homomorphic arithmetic. It abstracts over the underlying cryptosystem implementation, so CPU- and GPU-based backends can be used interchangeably behind the same user-facing interface.

**Note**: The current `CPUCryptoSystem` class stores pointers to the underlying implementation types in its tensors and vectors, rather than the types themselves. It will be updated in the future to follow the interface described below strictly.

### Template Parameters

```cpp
template <
    typename CryptoSystemImpl,
    typename SecretKeyImpl,
    typename PublicKeyImpl,
    typename PlainTextImpl,
    typename CipherTextImpl,
    typename SecretKeyShareImpl,
    typename PartDecryptionResult
>
```

* **CryptoSystemImpl**: The implementation of the cryptosystem.
* **SecretKeyImpl**: The type representing the secret key.
* **PublicKeyImpl**: The type representing the public key.
* **PlainTextImpl**: The type representing plaintext data.
* **CipherTextImpl**: The type representing ciphertext data.
* **SecretKeyShareImpl**: The type representing a share of the secret key (used in distributed settings).
* **PartDecryptionResult**: The type representing the result of a partial decryption.
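
Because every operation is a member of `CryptoSystemImpl`, algorithms can be written once as templates over the type parameters above and reused with any backend. A minimal sketch, assuming `Vector<T>` behaves like `std::vector<T>`; `encrypted_sum` and the deliberately insecure `MockCryptoSystem` (whose "ciphertexts" are just wrapped plaintexts) are hypothetical and exist only to make the example runnable:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Assumption: the library's Vector<T> behaves like std::vector<T>.
template <typename T>
using Vector = std::vector<T>;

// Generic algorithm written only against the interface: folds a vector of
// ciphertexts with add_ciphertexts. Throws std::out_of_range if cts is empty.
template <typename CryptoSystemImpl, typename PublicKeyImpl, typename CipherTextImpl>
CipherTextImpl encrypted_sum(CryptoSystemImpl& cs, const PublicKeyImpl& pk,
                             const Vector<CipherTextImpl>& cts) {
    CipherTextImpl acc = cts.at(0);
    for (std::size_t i = 1; i < cts.size(); ++i) {
        acc = cs.add_ciphertexts(pk, acc, cts[i]);
    }
    return acc;
}

// Hypothetical, INSECURE mock backend: a "ciphertext" is the plaintext value
// in a wrapper. Only useful for unit-testing generic code.
struct MockCryptoSystem {
    struct SecretKey {};
    struct PublicKey {};
    struct PlainText { double v; };
    struct CipherText { double v; };

    SecretKey keygen() { return {}; }
    PublicKey keygen(SecretKey) { return {}; }
    PlainText make_plaintext(float value) { return {value}; }
    float get_float_from_plaintext(PlainText pt) { return static_cast<float>(pt.v); }
    CipherText encrypt(PublicKey, PlainText pt) { return {pt.v}; }
    PlainText decrypt(SecretKey, CipherText ct) { return {ct.v}; }
    CipherText add_ciphertexts(PublicKey, CipherText a, CipherText b) { return {a.v + b.v}; }
};
```

Swapping `MockCryptoSystem` for a real backend would require no change to `encrypted_sum`, which is the point of keeping the interface uniform.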

### Interface Functions

In the examples below, `cs` is an instance of `CryptoSystemImpl`.

#### Key Generation

* **`SecretKeyImpl keygen()`**

  Generates a new secret key.

  ```cpp
  SecretKeyImpl sk = cs.keygen();
  ```
* **`PublicKeyImpl keygen(SecretKeyImpl sk)`**

  Generates a public key from the given secret key.

  ```cpp
  PublicKeyImpl pk = cs.keygen(sk);
  ```
* **`Vector<Vector<SecretKeyShareImpl>> keygen(SecretKeyImpl sk, int threshold, int num_parties)`**

  Splits the secret key `sk` into shares for `num_parties` parties, for threshold decryption in a distributed setting.

  ```cpp
  auto sk_shares = cs.keygen(sk, threshold, num_parties);
  ```
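
The document does not specify how `keygen(sk, threshold, num_parties)` produces its shares. To illustrate the threshold idea only, the sketch below uses Shamir secret sharing (a swapped-in technique, not necessarily the library's) over a toy prime field: a secret becomes `num_parties` points on a random polynomial of degree `threshold - 1`, so any `threshold` shares recover it by Lagrange interpolation at zero. Reading `threshold` as "number of shares required" is an assumption, and `split`/`reconstruct` are hypothetical names:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <random>
#include <utility>
#include <vector>

constexpr std::uint64_t P = 2147483647ULL;  // toy prime field 2^31 - 1

// Operands are < 2^31, so the product fits in 64 bits.
std::uint64_t mulmod(std::uint64_t a, std::uint64_t b) { return (a * b) % P; }

std::uint64_t powmod(std::uint64_t b, std::uint64_t e) {
    std::uint64_t r = 1;
    b %= P;
    while (e > 0) {
        if (e & 1) r = mulmod(r, b);
        b = mulmod(b, b);
        e >>= 1;
    }
    return r;
}

// Inverse by Fermat's little theorem (P prime, a != 0 mod P).
std::uint64_t invmod(std::uint64_t a) { return powmod(a, P - 2); }

using Share = std::pair<std::uint64_t, std::uint64_t>;  // (x, f(x))

// f(0) = secret; the other threshold - 1 coefficients are random.
std::vector<Share> split(std::uint64_t secret, int threshold, int num_parties,
                         std::mt19937_64& rng) {
    std::uniform_int_distribution<std::uint64_t> coeff(0, P - 1);
    std::vector<std::uint64_t> poly{secret % P};
    for (int i = 1; i < threshold; ++i) poly.push_back(coeff(rng));
    std::vector<Share> shares;
    for (int x = 1; x <= num_parties; ++x) {
        std::uint64_t y = 0;  // Horner evaluation of f(x)
        for (auto it = poly.rbegin(); it != poly.rend(); ++it) {
            y = (mulmod(y, x) + *it) % P;
        }
        shares.emplace_back(x, y);
    }
    return shares;
}

// Lagrange basis at 0: L_i(0) = prod_{j != i} x_j / (x_j - x_i).
std::uint64_t reconstruct(const std::vector<Share>& shares) {
    std::uint64_t secret = 0;
    for (std::size_t i = 0; i < shares.size(); ++i) {
        std::uint64_t num = 1, den = 1;
        for (std::size_t j = 0; j < shares.size(); ++j) {
            if (i == j) continue;
            num = mulmod(num, shares[j].first);
            den = mulmod(den, (shares[j].first + P - shares[i].first) % P);
        }
        std::uint64_t basis = mulmod(num, invmod(den));
        secret = (secret + mulmod(shares[i].second, basis)) % P;
    }
    return secret;
}
```

`std::mt19937_64` is used here for reproducibility only; real key sharing needs a cryptographically secure random source.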

#### Encryption

* **`CipherTextImpl encrypt(PublicKeyImpl pk, PlainTextImpl pt)`**

  Encrypts a plaintext using the public key.

  ```cpp
  CipherTextImpl ct = cs.encrypt(pk, pt);
  ```
* **`Vector<CipherTextImpl> encrypt(PublicKeyImpl pk, Vector<PlainTextImpl> pts)`**

  Encrypts a vector of plaintexts.

  ```cpp
  Vector<CipherTextImpl> cts = cs.encrypt(pk, pts);
  ```
* **`Tensor<CipherTextImpl> encrypt_tensor(PublicKeyImpl pk, Tensor<PlainTextImpl> pts)`**

  Encrypts a tensor of plaintexts.

  ```cpp
  Tensor<CipherTextImpl> cts = cs.encrypt_tensor(pk, pts);
  ```
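
The vector and tensor overloads presumably apply the scalar operation to every element independently; the document does not say whether a backend batches or parallelizes this. A minimal sketch of that element-wise lifting, assuming a hypothetical `Tensor<T>` that stores a shape plus flat row-major data (the library's real `Tensor` type is not described here):

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for the library's Tensor<T>.
template <typename T>
struct Tensor {
    std::vector<std::size_t> shape;
    std::vector<T> data;  // row-major elements
};

// Applies a scalar operation to every element; the output keeps the
// input's shape and the element type is whatever `op` returns.
template <typename T, typename F>
auto map_tensor(const Tensor<T>& in, F op) {
    using R = std::decay_t<decltype(op(in.data.front()))>;
    Tensor<R> out;
    out.shape = in.shape;
    out.data.reserve(in.data.size());
    for (const T& x : in.data) out.data.push_back(op(x));
    return out;
}
```

With a real backend, `map_tensor(pts, [&](const auto& pt) { return cs.encrypt(pk, pt); })` would mirror what `encrypt_tensor` is documented to do.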

#### Decryption

* **`PlainTextImpl decrypt(SecretKeyImpl sk, CipherTextImpl ct)`**

  Decrypts a ciphertext using the secret key.

  ```cpp
  PlainTextImpl pt = cs.decrypt(sk, ct);
  ```
* **`Vector<PlainTextImpl> decrypt_vector(SecretKeyImpl sk, Vector<CipherTextImpl> cts)`**

  Decrypts a vector of ciphertexts.

  ```cpp
  Vector<PlainTextImpl> pts = cs.decrypt_vector(sk, cts);
  ```
* **`Tensor<PlainTextImpl> decrypt_tensor(SecretKeyImpl sk, Tensor<CipherTextImpl> cts)`**

  Decrypts a tensor of ciphertexts.

  ```cpp
  Tensor<PlainTextImpl> pts = cs.decrypt_tensor(sk, cts);
  ```

#### Partial Decryption (Distributed)

* **`PartDecryptionResult part_decrypt(SecretKeyShareImpl sks, CipherTextImpl ct)`**

  Performs partial decryption using a secret key share.

  ```cpp
  PartDecryptionResult pdr = cs.part_decrypt(sks, ct);
  ```
* **`Vector<PartDecryptionResult> part_decrypt_vector(SecretKeyShareImpl sks, Vector<CipherTextImpl> cts)`**

  Partially decrypts each ciphertext in a vector.

  ```cpp
  Vector<PartDecryptionResult> pdrs = cs.part_decrypt_vector(sks, cts);
  ```
* **`Tensor<PartDecryptionResult> part_decrypt_tensor(SecretKeyShareImpl sks, Tensor<CipherTextImpl> cts)`**

  Partially decrypts each ciphertext in a tensor.

  ```cpp
  Tensor<PartDecryptionResult> pdrs = cs.part_decrypt_tensor(sks, cts);
  ```
* **`PlainTextImpl combine_part_decryption_results(CipherTextImpl ct, Vector<PartDecryptionResult> pdrs)`**

  Combines partial decryption results to recover the plaintext.

  ```cpp
  PlainTextImpl pt = cs.combine_part_decryption_results(ct, pdrs);
  ```
* **`Vector<PlainTextImpl> combine_part_decryption_results_vector(CipherTextImpl ct, Vector<PartDecryptionResult> pdrs)`**

  Combines partial decryption results for a vector of ciphertexts.

  ```cpp
  Vector<PlainTextImpl> pts = cs.combine_part_decryption_results_vector(ct, pdrs);
  ```
* **`Tensor<PlainTextImpl> combine_part_decryption_results_tensor(CipherTextImpl ct, Vector<Tensor<PartDecryptionResult>> pdrs)`**

  Combines partial decryption results for a tensor of ciphertexts.

  ```cpp
  Tensor<PlainTextImpl> pts = cs.combine_part_decryption_results_tensor(ct, pdrs);
  ```
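
The document does not name the scheme behind `part_decrypt` and `combine_part_decryption_results`. To show the data flow, here is a toy n-of-n variant of exponential ElGamal (a swapped-in technique chosen for clarity): the secret key `x` is split additively into per-party shares `x_i`, each party returns `c1^x_i` as its partial result, and the combiner multiplies the partial results to obtain `c1^x` and strip it from `c2`. The group is tiny and insecure, only small messages can be recovered, and all names are hypothetical:

```cpp
#include <cassert>
#include <cstdint>
#include <random>
#include <vector>

// Toy group: P = 2Q + 1 with Q prime; G = 4 generates the order-Q subgroup.
constexpr std::uint64_t P = 2039, Q = 1019, G = 4;

std::uint64_t powmod(std::uint64_t b, std::uint64_t e) {
    std::uint64_t r = 1;
    b %= P;
    while (e) { if (e & 1) r = r * b % P; b = b * b % P; e >>= 1; }
    return r;
}

struct CipherText { std::uint64_t c1, c2; };

// Additive n-of-n sharing of the secret key: x = sum(shares) mod Q (n >= 1).
std::vector<std::uint64_t> share_key(std::uint64_t x, int n, std::mt19937_64& rng) {
    std::uniform_int_distribution<std::uint64_t> d(0, Q - 1);
    std::vector<std::uint64_t> shares;
    std::uint64_t sum = 0;
    for (int i = 0; i + 1 < n; ++i) {
        shares.push_back(d(rng));
        sum = (sum + shares.back()) % Q;
    }
    shares.push_back((x % Q + Q - sum) % Q);  // last share fixes the total
    return shares;
}

// Exponential ElGamal: Enc(m) = (g^r, g^m * h^r) with public key h = g^x.
CipherText encrypt(std::uint64_t h, std::uint64_t m, std::mt19937_64& rng) {
    std::uniform_int_distribution<std::uint64_t> d(1, Q - 1);
    std::uint64_t r = d(rng);
    return {powmod(G, r), powmod(G, m) * powmod(h, r) % P};
}

// One party's partial decryption: c1^{x_i}.
std::uint64_t part_decrypt(std::uint64_t x_i, const CipherText& ct) {
    return powmod(ct.c1, x_i);
}

// prod(c1^{x_i}) = c1^x = h^r; dividing it out of c2 leaves g^m, and the
// small m is then found by brute-force discrete logarithm.
std::uint64_t combine(const CipherText& ct, const std::vector<std::uint64_t>& parts) {
    std::uint64_t mask = 1;
    for (std::uint64_t d : parts) mask = mask * d % P;
    std::uint64_t gm = ct.c2 * powmod(mask, P - 2) % P;  // Fermat inverse
    std::uint64_t acc = 1;
    for (std::uint64_t m = 0; m < Q; ++m) {
        if (acc == gm) return m;
        acc = acc * G % P;
    }
    return Q;  // not reached for valid inputs
}
```

As in the documented `combine_part_decryption_results(ct, pdrs)`, the combiner needs the original ciphertext as well as every party's partial result.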

#### Homomorphic Operations on Ciphertexts

* **`CipherTextImpl add_ciphertexts(PublicKeyImpl pk, CipherTextImpl ct1, CipherTextImpl ct2)`**

  Adds two ciphertexts homomorphically.

  ```cpp
  CipherTextImpl ct_sum = cs.add_ciphertexts(pk, ct1, ct2);
  ```
* **`CipherTextImpl scal_ciphertext(PublicKeyImpl pk, PlainTextImpl pt, CipherTextImpl ct)`**

  Multiplies a ciphertext by a plaintext scalar homomorphically.

  ```cpp
  CipherTextImpl ct_scaled = cs.scal_ciphertext(pk, pt, ct);
  ```
* **`Vector<CipherTextImpl> add_ciphertext_vectors(PublicKeyImpl pk, Vector<CipherTextImpl> cts1, Vector<CipherTextImpl> cts2)`**

  Adds two vectors of ciphertexts element-wise.

  ```cpp
  Vector<CipherTextImpl> cts_sum = cs.add_ciphertext_vectors(pk, cts1, cts2);
  ```
* **`Vector<CipherTextImpl> scal_ciphertext_vector(PublicKeyImpl pk, PlainTextImpl pt, Vector<CipherTextImpl> cts)`**

  Multiplies each ciphertext in the vector by a plaintext scalar.

  ```cpp
  Vector<CipherTextImpl> cts_scaled = cs.scal_ciphertext_vector(pk, pt, cts);
  ```
* **`Vector<CipherTextImpl> scal_ciphertext_vector(PublicKeyImpl pk, Vector<PlainTextImpl> pts, Vector<CipherTextImpl> cts)`**

  Multiplies each ciphertext in the vector by the corresponding plaintext scalar in `pts`.

  ```cpp
  Vector<CipherTextImpl> cts_scaled = cs.scal_ciphertext_vector(pk, pts, cts);
  ```
* **`Tensor<CipherTextImpl> add_ciphertext_tensors(PublicKeyImpl pk, Tensor<CipherTextImpl> cts1, Tensor<CipherTextImpl> cts2)`**

  Adds two tensors of ciphertexts element-wise.

  ```cpp
  Tensor<CipherTextImpl> cts_sum = cs.add_ciphertext_tensors(pk, cts1, cts2);
  ```
* **`Tensor<CipherTextImpl> scal_ciphertext_tensors(PublicKeyImpl pk, Tensor<PlainTextImpl> pts, Tensor<CipherTextImpl> cts)`**

  Multiplies each ciphertext in the tensor element-wise by the corresponding plaintext in `pts`.

  ```cpp
  Tensor<CipherTextImpl> cts_scaled = cs.scal_ciphertext_tensors(pk, pts, cts);
  ```
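
The document does not say which scheme a backend uses. One well-known scheme that supports exactly this set of operations is Paillier: multiplying ciphertexts adds the plaintexts, raising a ciphertext to a plaintext power multiplies by that scalar, and raising it to `n - 1` negates. The sketch below is a toy Paillier instance with tiny primes (insecure, illustration only); its free functions drop the `pk` argument in favor of a global modulus `N` and are hypothetical:

```cpp
#include <cassert>
#include <cstdint>
#include <random>

using u64 = std::uint64_t;

// Toy parameters p = 241, q = 251: n^2 < 2^32, so 64-bit products never overflow.
constexpr u64 N = 241ULL * 251ULL;  // 60491
constexpr u64 N2 = N * N;
constexpr u64 LAMBDA = 6000;        // lcm(p - 1, q - 1) = lcm(240, 250)

u64 mulmod(u64 a, u64 b, u64 m) { return a % m * (b % m) % m; }

u64 powmod(u64 b, u64 e, u64 m) {
    u64 r = 1 % m;
    b %= m;
    while (e) { if (e & 1) r = mulmod(r, b, m); b = mulmod(b, b, m); e >>= 1; }
    return r;
}

u64 gcd(u64 a, u64 b) { while (b) { u64 t = a % b; a = b; b = t; } return a; }

// Extended Euclid; requires gcd(a, m) = 1.
u64 invmod(u64 a, u64 m) {
    std::int64_t t = 0, new_t = 1;
    std::int64_t r = static_cast<std::int64_t>(m), new_r = static_cast<std::int64_t>(a % m);
    while (new_r != 0) {
        std::int64_t quot = r / new_r;
        std::int64_t tmp = t - quot * new_t; t = new_t; new_t = tmp;
        tmp = r - quot * new_r; r = new_r; new_r = tmp;
    }
    return static_cast<u64>(t < 0 ? t + static_cast<std::int64_t>(m) : t);
}

const u64 MU = invmod(LAMBDA % N, N);  // mu = lambda^{-1} mod n, valid for g = n + 1

// Enc(m) = (1 + m*n) * r^n mod n^2, with random r coprime to n.
u64 encrypt(u64 m, std::mt19937_64& rng) {
    std::uniform_int_distribution<u64> d(1, N - 1);
    u64 r;
    do { r = d(rng); } while (gcd(r, N) != 1);
    return mulmod((1 + mulmod(m, N, N2)) % N2, powmod(r, N, N2), N2);
}

// Dec(c) = L(c^lambda mod n^2) * mu mod n, where L(u) = (u - 1) / n.
u64 decrypt(u64 c) {
    u64 u = powmod(c, LAMBDA, N2);
    return mulmod((u - 1) / N, MU, N);
}

// Enc(a) * Enc(b) = Enc(a + b mod n).
u64 add_ciphertexts(u64 c1, u64 c2) { return mulmod(c1, c2, N2); }
// Enc(a)^k = Enc(k * a mod n).
u64 scal_ciphertext(u64 k, u64 c) { return powmod(c, k, N2); }
// Enc(a)^(n - 1) = Enc(-a mod n).
u64 negate_ciphertext(u64 c) { return powmod(c, N - 1, N2); }
```

Note that a negated value decrypts to `n - a`, so a backend with signed plaintexts has to interpret large residues as negative numbers.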

#### Plaintext Operations

* **`PlainTextImpl generate_random_plaintext()`**

  Generates a random plaintext.

  ```cpp
  PlainTextImpl random_pt = cs.generate_random_plaintext();
  ```
* **`Vector<PlainTextImpl> generate_random_beavers_triplet()`**

  Generates a random Beaver triple for secure multiplication: three plaintexts `a`, `b`, `c` with `c = a * b`, returned as a `Vector`.

  ```cpp
  Vector<PlainTextImpl> triplet = cs.generate_random_beavers_triplet();
  ```
* **`PlainTextImpl add_plaintexts(PlainTextImpl pt1, PlainTextImpl pt2)`**

  Adds two plaintexts.

  ```cpp
  PlainTextImpl pt_sum = cs.add_plaintexts(pt1, pt2);
  ```
* **`PlainTextImpl multiply_plaintexts(PlainTextImpl pt1, PlainTextImpl pt2)`**

  Multiplies two plaintexts.

  ```cpp
  PlainTextImpl pt_product = cs.multiply_plaintexts(pt1, pt2);
  ```
* **`Tensor<PlainTextImpl> add_plaintext_tensors(Tensor<PlainTextImpl> pts1, Tensor<PlainTextImpl> pts2)`**

  Adds two tensors of plaintexts element-wise.

  ```cpp
  Tensor<PlainTextImpl> pts_sum = cs.add_plaintext_tensors(pts1, pts2);
  ```
* **`Tensor<PlainTextImpl> multiply_plaintext_tensors(Tensor<PlainTextImpl> pts1, Tensor<PlainTextImpl> pts2)`**

  Multiplies two tensors of plaintexts element-wise.

  ```cpp
  Tensor<PlainTextImpl> pts_product = cs.multiply_plaintext_tensors(pts1, pts2);
  ```
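
The document does not describe how the triple from `generate_random_beavers_triplet` is consumed. In the standard Beaver technique, parties holding additive shares of `x` and `y`, plus shares of a random triple `(a, b, c = a * b)`, obtain shares of `x * y` while opening only the masked values `d = x - a` and `e = y - b`. A two-party sketch over integers modulo a hypothetical prime `MOD`, with plain integers standing in for `PlainTextImpl`:

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <random>

using u64 = std::uint64_t;
constexpr u64 MOD = 2147483647ULL;  // toy prime modulus 2^31 - 1

u64 add(u64 a, u64 b) { return (a + b) % MOD; }
u64 sub(u64 a, u64 b) { return (a + MOD - b) % MOD; }
u64 mul(u64 a, u64 b) { return a * b % MOD; }  // operands < 2^31

// Two additive shares: v = s0 + s1 (mod MOD).
std::array<u64, 2> share(u64 v, std::mt19937_64& rng) {
    std::uniform_int_distribution<u64> d(0, MOD - 1);
    u64 s0 = d(rng);
    return {s0, sub(v % MOD, s0)};
}

// Beaver multiplication on shares of x, y and of a triple (a, b, c = a*b).
std::array<u64, 2> beaver_mul(const std::array<u64, 2>& x, const std::array<u64, 2>& y,
                              const std::array<u64, 2>& a, const std::array<u64, 2>& b,
                              const std::array<u64, 2>& c) {
    // Both parties publish their shares of d and e, so d and e become public.
    u64 d = add(sub(x[0], a[0]), sub(x[1], a[1]));
    u64 e = add(sub(y[0], b[0]), sub(y[1], b[1]));
    std::array<u64, 2> z;
    for (int i = 0; i < 2; ++i) {
        z[i] = add(c[i], add(mul(d, b[i]), mul(e, a[i])));  // c_i + d*b_i + e*a_i
    }
    z[0] = add(z[0], mul(d, e));  // exactly one party adds the public term d*e
    return z;
}
```

Summing the output shares gives `ab + (x - a)b + (y - b)a + (x - a)(y - b) = xy`, and since `a` and `b` are uniformly random, `d` and `e` reveal nothing about `x` and `y`.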

#### Negation Operations

* **`PlainTextImpl negate_plaintext(PlainTextImpl pt)`**

  Negates a plaintext.

  ```cpp
  PlainTextImpl pt_neg = cs.negate_plaintext(pt);
  ```
* **`Tensor<PlainTextImpl> negate_plain_tensor(Tensor<PlainTextImpl> pts)`**

  Negates each plaintext in a tensor.

  ```cpp
  Tensor<PlainTextImpl> pts_neg = cs.negate_plain_tensor(pts);
  ```
* **`CipherTextImpl negate_ciphertext(PublicKeyImpl pk, CipherTextImpl ct)`**

  Negates a ciphertext homomorphically.

  ```cpp
  CipherTextImpl ct_neg = cs.negate_ciphertext(pk, ct);
  ```
* **`Tensor<CipherTextImpl> negate_ciphertext_tensor(PublicKeyImpl pk, Tensor<CipherTextImpl> cts)`**

  Negates each ciphertext in a tensor homomorphically.

  ```cpp
  Tensor<CipherTextImpl> cts_neg = cs.negate_ciphertext_tensor(pk, cts);
  ```
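
Assuming the plaintext space is the integers modulo some `n` (the document does not state this), negation maps `m` to `n - m`, negative values are read back with a centered representation, and subtraction can be written as addition of a negated operand; the same pattern would combine `negate_ciphertext` with `add_ciphertexts`. A plaintext-only sketch with a hypothetical modulus and hypothetical `encode`/`decode`/`subtract` helpers:

```cpp
#include <cassert>
#include <cstdint>

using u64 = std::uint64_t;
using i64 = std::int64_t;
constexpr u64 N = 1000003;  // hypothetical plaintext modulus

// Map a signed integer into [0, N).
u64 encode(i64 v) {
    const i64 n = static_cast<i64>(N);
    return static_cast<u64>(((v % n) + n) % n);
}

// Centered decoding: residues above N/2 represent negative numbers.
i64 decode(u64 m) {
    return m > N / 2 ? static_cast<i64>(m) - static_cast<i64>(N) : static_cast<i64>(m);
}

u64 negate_plaintext(u64 m) { return (N - m % N) % N; }
u64 add_plaintexts(u64 a, u64 b) { return (a + b) % N; }

// a - b computed as a + (-b).
u64 subtract(u64 a, u64 b) { return add_plaintexts(a, negate_plaintext(b)); }
```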

#### Plaintext Creation and Retrieval

* **`PlainTextImpl make_plaintext(float value)`**

  Creates a plaintext from a floating-point value.

  ```cpp
  PlainTextImpl pt = cs.make_plaintext(3.14f);
  ```
* **`float get_float_from_plaintext(PlainTextImpl pt)`**

  Retrieves the floating-point value from a plaintext.

  ```cpp
  float value = cs.get_float_from_plaintext(pt);
  ```
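
The document does not say how `make_plaintext` maps a `float` into the plaintext space. A common approach, shown here only as an assumption, is fixed-point encoding: scale by `2^FRAC_BITS`, round to an integer, and store it modulo `N` with negatives wrapped; `get_float_from_plaintext` reverses the mapping. The modulus and precision below are hypothetical:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

using u64 = std::uint64_t;
using i64 = std::int64_t;
constexpr u64 N = 1ULL << 40;  // hypothetical plaintext modulus
constexpr int FRAC_BITS = 16;  // hypothetical fixed-point precision
constexpr double SCALE = static_cast<double>(1ULL << FRAC_BITS);

u64 make_plaintext(float value) {
    i64 v = static_cast<i64>(std::llround(static_cast<double>(value) * SCALE));
    const i64 n = static_cast<i64>(N);
    return static_cast<u64>(((v % n) + n) % n);  // wrap negatives into [0, N)
}

float get_float_from_plaintext(u64 pt) {
    // Centered decoding: residues above N/2 are negative values.
    i64 v = pt > N / 2 ? static_cast<i64>(pt) - static_cast<i64>(N) : static_cast<i64>(pt);
    return static_cast<float>(static_cast<double>(v) / SCALE);
}

// Encodings share the same scale, so addition needs no adjustment.
u64 add_plaintexts(u64 a, u64 b) { return (a + b) % N; }
```

Under this encoding, the product of two encoded values carries scale `2^(2 * FRAC_BITS)`, so a backend using it would have to rescale after `multiply_plaintexts` or `scal_ciphertext`.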

#### Serialization

* **`String serialize()`**

  Serializes the cryptosystem instance.

  ```cpp
  String serialized_cs = cs.serialize();
  ```
* **`String serialize_secret_key(SecretKeyImpl sk)`**

  Serializes a secret key.

  ```cpp
  String serialized_sk = cs.serialize_secret_key(sk);
  ```
* **`String serialize_secret_key_share(SecretKeyShareImpl sks)`**

  Serializes a secret key share.

  ```cpp
  String serialized_sks = cs.serialize_secret_key_share(sks);
  ```
* **`String serialize_public_key(PublicKeyImpl pk)`**

  Serializes a public key.

  ```cpp
  String serialized_pk = cs.serialize_public_key(pk);
  ```
* **`String serialize_plaintext(PlainTextImpl pt)`**

  Serializes a plaintext.

  ```cpp
  String serialized_pt = cs.serialize_plaintext(pt);
  ```
* **`String serialize_ciphertext(CipherTextImpl ct)`**

  Serializes a ciphertext.

  ```cpp
  String serialized_ct = cs.serialize_ciphertext(ct);
  ```
* **`String serialize_part_decryption_result(PartDecryptionResult pdr)`**

  Serializes a partial decryption result.

  ```cpp
  String serialized_pdr = cs.serialize_part_decryption_result(pdr);
  ```
* **`String serialize_plaintext_tensor(Tensor<PlainTextImpl> pts)`**

  Serializes a tensor of plaintexts.

  ```cpp
  String serialized_pts = cs.serialize_plaintext_tensor(pts);
  ```
* **`String serialize_ciphertext_tensor(Tensor<CipherTextImpl> cts)`**

  Serializes a tensor of ciphertexts.

  ```cpp
  String serialized_cts = cs.serialize_ciphertext_tensor(cts);
  ```
* **`String serialize_part_decryption_result_tensor(Tensor<PartDecryptionResult> pdrs)`**

  Serializes a tensor of partial decryption results.

  ```cpp
  String serialized_pdrs = cs.serialize_part_decryption_result_tensor(pdrs);
  ```
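
The serialization format is backend-specific and not specified here. As one possible design (hypothetical), a tensor can be written as its rank, then its dimensions, then its elements in row-major order, each produced by the scalar serializer. The sketch assumes `String` is `std::string`, uses a hypothetical `Tensor` with 64-bit integer elements, and stands in `serialize_element` for `serialize_plaintext`/`serialize_ciphertext`:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical tensor: shape plus flat row-major elements.
struct Tensor {
    std::vector<std::size_t> shape;
    std::vector<std::uint64_t> data;
};

// Stand-in for the per-element serializer.
std::string serialize_element(std::uint64_t v) { return std::to_string(v); }

// Layout: "<rank> <dim_0> ... <dim_{rank-1}> <elem_0> ... <elem_{k-1}>".
std::string serialize_tensor(const Tensor& t) {
    std::ostringstream out;
    out << t.shape.size();
    for (std::size_t dim : t.shape) out << ' ' << dim;
    for (std::uint64_t v : t.data) out << ' ' << serialize_element(v);
    return out.str();
}
```

A space-separated layout only works because these stand-in elements are decimal numbers; serialized keys or ciphertexts may contain arbitrary bytes, so a real format would length-prefix each element.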

#### Deserialization

* **`static CryptoSystemImpl deserialize(String data)`**

  Static member function that reconstructs a cryptosystem instance from a string produced by `serialize()`.

  ```cpp
  CryptoSystemImpl cs = CryptoSystemImpl::deserialize(serialized_cs);
  ```
* **`SecretKeyImpl deserialize_secret_key(String data)`**

  Deserializes a secret key from a string.

  ```cpp
  SecretKeyImpl sk = cs.deserialize_secret_key(serialized_sk);
  ```
* **`SecretKeyShareImpl deserialize_secret_key_share(String data)`**

  Deserializes a secret key share from a string.

  ```cpp
  SecretKeyShareImpl sks = cs.deserialize_secret_key_share(serialized_sks);
  ```
* **`PublicKeyImpl deserialize_public_key(String data)`**

  Deserializes a public key from a string.

  ```cpp
  PublicKeyImpl pk = cs.deserialize_public_key(serialized_pk);
  ```
* **`PlainTextImpl deserialize_plaintext(String data)`**

  Deserializes a plaintext from a string.

  ```cpp
  PlainTextImpl pt = cs.deserialize_plaintext(serialized_pt);
  ```
* **`CipherTextImpl deserialize_ciphertext(String data)`**

  Deserializes a ciphertext from a string.

  ```cpp
  CipherTextImpl ct = cs.deserialize_ciphertext(serialized_ct);
  ```
* **`PartDecryptionResult deserialize_part_decryption_result(String data)`**

  Deserializes a partial decryption result from a string.

  ```cpp
  PartDecryptionResult pdr = cs.deserialize_part_decryption_result(serialized_pdr);
  ```
* **`Tensor<PlainTextImpl> deserialize_plaintext_tensor(String data)`**

  Deserializes a tensor of plaintexts from a string.

  ```cpp
  Tensor<PlainTextImpl> pts = cs.deserialize_plaintext_tensor(serialized_pts);
  ```
* **`Tensor<CipherTextImpl> deserialize_ciphertext_tensor(String data)`**

  Deserializes a tensor of ciphertexts from a string.

  ```cpp
  Tensor<CipherTextImpl> cts = cs.deserialize_ciphertext_tensor(serialized_cts);
  ```
* **`Tensor<PartDecryptionResult> deserialize_part_decryption_result_tensor(String data)`**

  Deserializes a tensor of partial decryption results from a string.

  ```cpp
  Tensor<PartDecryptionResult> pdrs = cs.deserialize_part_decryption_result_tensor(serialized_pdrs);
  ```
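
Deserialization should validate its input, since serialized data may arrive over a network. Assuming a hypothetical space-separated layout (rank, then dimensions, then row-major elements; the real format is backend-specific), a tensor deserializer can check that the element count matches the product of the dimensions and reject truncated or trailing data. `Tensor` and `deserialize_tensor` are hypothetical:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical tensor: shape plus flat row-major elements.
struct Tensor {
    std::vector<std::size_t> shape;
    std::vector<std::uint64_t> data;
};

// Parses "<rank> <dims...> <elems...>". A rank-0 tensor holds one element.
Tensor deserialize_tensor(const std::string& s) {
    std::istringstream in(s);
    std::size_t rank = 0;
    if (!(in >> rank)) throw std::invalid_argument("missing rank");
    Tensor t;
    std::size_t count = 1;
    for (std::size_t i = 0; i < rank; ++i) {
        std::size_t dim = 0;
        if (!(in >> dim)) throw std::invalid_argument("missing dimension");
        t.shape.push_back(dim);
        count *= dim;
    }
    for (std::size_t i = 0; i < count; ++i) {
        std::uint64_t v = 0;
        if (!(in >> v)) throw std::invalid_argument("too few elements");
        t.data.push_back(v);
    }
    std::string extra;
    if (in >> extra) throw std::invalid_argument("trailing data");
    return t;
}
```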

### Notes

* Serialization and deserialization allow cryptographic objects to be stored and transmitted over a network.

### Example Usage

#### Key Generation and Encryption

```cpp
// Generate keys
SecretKeyImpl sk = cs.keygen();
PublicKeyImpl pk = cs.keygen(sk);

// Create a plaintext
PlainTextImpl pt = cs.make_plaintext(42.0f);

// Encrypt the plaintext
CipherTextImpl ct = cs.encrypt(pk, pt);
```

#### Homomorphic Addition and Decryption

```cpp
// Encrypt two plaintexts
PlainTextImpl pt1 = cs.make_plaintext(10.0f);
PlainTextImpl pt2 = cs.make_plaintext(15.0f);
CipherTextImpl ct1 = cs.encrypt(pk, pt1);
CipherTextImpl ct2 = cs.encrypt(pk, pt2);

// Perform homomorphic addition
CipherTextImpl ct_sum = cs.add_ciphertexts(pk, ct1, ct2);

// Decrypt the result
PlainTextImpl pt_sum = cs.decrypt(sk, ct_sum);
float result = cs.get_float_from_plaintext(pt_sum); // result should be 25.0f
```

#### Serialization and Deserialization

```cpp
// Serialize keys and ciphertext
String serialized_sk = cs.serialize_secret_key(sk);
String serialized_pk = cs.serialize_public_key(pk);
String serialized_ct = cs.serialize_ciphertext(ct);

// Deserialize keys and ciphertext
SecretKeyImpl deserialized_sk = cs.deserialize_secret_key(serialized_sk);
PublicKeyImpl deserialized_pk = cs.deserialize_public_key(serialized_pk);
CipherTextImpl deserialized_ct = cs.deserialize_ciphertext(serialized_ct);
```
