Skip to content

Commit bf26507

Browse files
committed
Keep Context with HeScheme generic instead of Scalar
1 parent 28c8e83 commit bf26507

File tree

12 files changed

+33
-27
lines changed

12 files changed

+33
-27
lines changed

Sources/HomomorphicEncryption/Bfv/Bfv.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import ModularArithmetic
1616

1717
/// Brakerski-Fan-Vercauteren cryptosystem.
1818
public enum Bfv<T: ScalarType>: HeScheme {
19-
public typealias Context = HomomorphicEncryption.Context<T>
19+
public typealias Context = HomomorphicEncryption.Context<Self>
2020

2121
public typealias Scalar = T
2222

Sources/HomomorphicEncryption/Context.swift

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,16 @@
1616
///
1717
/// HE operations are typically only supported between objects, such as ``Ciphertext``, ``Plaintext``,
1818
/// ``EvaluationKey``, ``SecretKey``, with the same context.
19-
public final class Context<Scalar: ScalarType>: Equatable, Sendable, HeContext {
19+
public final class Context<Scheme: HeScheme>: Equatable, Sendable, HeContext {
20+
public typealias Scalar = Scheme.Scalar
21+
22+
/// The (row, column) dimension counts for ``EncodeFormat/simd`` encoding.
23+
///
24+
/// If the HE scheme does not support ``EncodeFormat/simd`` encoding, returns `nil`.
25+
public var simdDimensions: SimdEncodingDimensions? {
26+
Scheme.encodeSimdDimensions(for: encryptionParameters)
27+
}
28+
2029
/// Encryption parameters.
2130
public let encryptionParameters: EncryptionParameters<Scalar>
2231

@@ -124,15 +133,6 @@ public final class Context<Scalar: ScalarType>: Equatable, Sendable, HeContext {
124133
lhs === rhs || lhs.encryptionParameters == rhs.encryptionParameters
125134
}
126135

127-
/// The (row, column) dimension counts for ``EncodeFormat/simd`` encoding.
128-
///
129-
/// If the HE scheme does not support ``EncodeFormat/simd`` encoding, returns `nil`.
130-
public func simdDimensions<Scheme: HeScheme>(for _: Scheme.Type) -> SimdEncodingDimensions?
131-
where Scheme.Scalar == Scalar
132-
{
133-
Scheme.encodeSimdDimensions(for: encryptionParameters)
134-
}
135-
136136
@inlinable
137137
public func _getRnsTool(moduliCount: Int) -> _RnsTool<Scalar> {
138138
precondition(moduliCount <= rnsTools.count && moduliCount > 0, "Invalid number of moduli")

Sources/HomomorphicEncryption/HeScheme.swift

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,19 +86,18 @@ public struct SimdEncodingDimensions: Codable, Equatable, Hashable, Sendable {
8686
}
8787

8888
public protocol HeContext: Equatable, Sendable, CustomStringConvertible {
89-
associatedtype Scalar: ScalarType
89+
associatedtype Scheme: HeScheme where Scheme.Scalar == Self.Scalar
90+
typealias Scalar = Scheme.Scalar
9091

9192
var encryptionParameters: EncryptionParameters<Scalar> { get }
9293
var ciphertextContext: PolyContext<Scalar> { get }
9394
var plaintextContext: PolyContext<Scalar> { get }
9495
var secretKeyContext: PolyContext<Scalar> { get }
96+
var simdDimensions: SimdEncodingDimensions? { get }
9597
var simdEncodingMatrix: [Int] { get }
9698

9799
init(encryptionParameters: EncryptionParameters<Scalar>) throws
98100
func _getRnsTool(moduliCount: Int) throws -> _RnsTool<Scalar>
99-
100-
func simdDimensions<Scheme: HeScheme>(for _: Scheme.Type) -> SimdEncodingDimensions?
101-
where Scheme.Scalar == Scalar
102101
}
103102

104103
extension HeContext {
@@ -116,6 +115,13 @@ extension HeContext {
116115
public var bitsPerPlaintext: Int { encryptionParameters.bitsPerPlaintext }
117116
/// The number of bytes that can be encoded in a single ``Plaintext``.
118117
public var bytesPerPlaintext: Int { encryptionParameters.bytesPerPlaintext }
118+
119+
/// The (row, column) dimension counts for ``EncodeFormat/simd`` encoding.
120+
///
121+
/// If the HE scheme does not support ``EncodeFormat/simd`` encoding, returns `nil`.
122+
public var simdDimensions: SimdEncodingDimensions? {
123+
Scheme.encodeSimdDimensions(for: encryptionParameters)
124+
}
119125
}
120126

121127
/// Protocol for HE schemes.

Sources/HomomorphicEncryption/NoOpScheme.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
/// The scheme simply takes the plaintext as a "ciphertext" and
1818
/// ignores any ciphertext coefficient moduli.
1919
public enum NoOpScheme: HeScheme {
20-
public typealias Context = HomomorphicEncryption.Context<UInt64>
20+
public typealias Context = HomomorphicEncryption.Context<Self>
2121
public typealias Scalar = UInt64
2222
public typealias CanonicalCiphertextFormat = Coeff
2323

Sources/PrivateNearestNeighborSearch/CiphertextMatrix.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public struct CiphertextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
6666
throw PnnsError.emptyCiphertextArray
6767
}
6868
let encryptionParameters = context.encryptionParameters
69-
guard let simdDimensions = context.simdDimensions(for: Scheme.self) else {
69+
guard let simdDimensions = context.simdDimensions else {
7070
throw PnnsError.simdEncodingNotSupported(for: encryptionParameters)
7171
}
7272
let expectedCiphertextCount = try PlaintextMatrix<Scheme, Format>.plaintextCount(

Sources/PrivateNearestNeighborSearch/PlaintextMatrix.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
121121
}
122122
let context = plaintexts[0].context
123123
let encryptionParameters = context.encryptionParameters
124-
guard let simdDimensions = context.simdDimensions(for: Scheme.self) else {
124+
guard let simdDimensions = context.simdDimensions else {
125125
throw PnnsError.simdEncodingNotSupported(for: encryptionParameters)
126126
}
127127
let expectedPlaintextCount = try PlaintextMatrix.plaintextCount(
@@ -276,7 +276,7 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
276276
values: [Scalar]) throws -> [Scheme.CoeffPlaintext]
277277
{
278278
let degree = context.degree
279-
guard let simdColumnCount = context.simdDimensions(for: Scheme.self)?.columnCount else {
279+
guard let simdColumnCount = context.simdDimensions?.columnCount else {
280280
throw PnnsError.simdEncodingNotSupported(for: context.encryptionParameters)
281281
}
282282

@@ -334,7 +334,7 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
334334
values: [Scalar]) throws -> [Plaintext<Scheme, Coeff>]
335335
{
336336
let encryptionParameters = context.encryptionParameters
337-
guard let simdDimensions = context.simdDimensions(for: Scheme.self) else {
337+
guard let simdDimensions = context.simdDimensions else {
338338
throw PnnsError.simdEncodingNotSupported(for: encryptionParameters)
339339
}
340340
guard simdDimensions.rowCount == 2 else {
@@ -411,7 +411,7 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
411411
values: [Scalar]) throws -> [Scheme.CoeffPlaintext]
412412
{
413413
let encryptionParameters = context.encryptionParameters
414-
guard let simdDimensions = context.simdDimensions(for: Scheme.self) else {
414+
guard let simdDimensions = context.simdDimensions else {
415415
throw PnnsError.simdEncodingNotSupported(for: encryptionParameters)
416416
}
417417
let simdColumnCount = simdDimensions.columnCount

Sources/TestUtilities/HeApiTestUtils.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ public enum HeAPITestHelpers {
128128
}
129129

130130
/// test the evaluation key configuration
131-
public static func schemeEvaluationKeyTest(context _: Context<some ScalarType>) throws {
131+
public static func schemeEvaluationKeyTest(context _: Context<some HeScheme>) throws {
132132
do {
133133
let config = EvaluationKeyConfig()
134134
#expect(!config.hasRelinearizationKey)

Sources/TestUtilities/PirUtilities/IndexPirTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ extension PirTestUtils {
2222
/// Testing client configuration.
2323
@inlinable
2424
func generateParameter() throws {
25-
let context: Context<UInt64> = try TestUtils.getTestContext()
25+
let context: Context<Bfv<UInt64>> = try TestUtils.getTestContext()
2626
// unevenDimensions: false
2727
do {
2828
let config = try IndexPirConfig(entryCount: 16,

Tests/HomomorphicEncryptionTests/EncryptionParametersTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ struct EncryptionParametersTests {
243243

244244
// check skipLSBsForDecryption
245245
do {
246-
let context = try Context<UInt64>(encryptionParameters: params)
246+
let context = try Context<Bfv<UInt64>>(encryptionParameters: params)
247247
let data = TestUtils.getRandomPlaintextData(count: params.polyDegree, in: 0..<params.plaintextModulus)
248248
let plaintext: Plaintext<Bfv<UInt64>, Coeff> = try context.encode(values: data, format: .coefficient)
249249
let secretKey: SecretKey<Bfv<UInt64>> = try context.generateSecretKey()

Tests/HomomorphicEncryptionTests/HeAPITests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ struct HeAPITests {
8686

8787
@Test
8888
func testNoOpScheme() async throws {
89-
let context: Context<NoOpScheme.Scalar> = try TestUtils.getTestContext()
89+
let context: Context<NoOpScheme> = try TestUtils.getTestContext()
9090
try HeAPITestHelpers.schemeEncodeDecodeTest(context: context, scheme: NoOpScheme.self)
9191
try HeAPITestHelpers.schemeEncryptDecryptTest(context: context, scheme: NoOpScheme.self)
9292
try HeAPITestHelpers.schemeEncryptZeroDecryptTest(context: context, scheme: NoOpScheme.self)

0 commit comments

Comments
 (0)