Description
In working on #2561, my main test case was a single loop that repeatedly muls the input ciphertext:

```mlir
func.func @loop_mul(%arg0: !secret.secret<tensor<8xf32>>) -> !secret.secret<tensor<8xf32>> {
  %c2 = arith.constant dense<2.0> : tensor<8xf32>
  %0 = secret.generic(%arg0: !secret.secret<tensor<8xf32>>) {
  ^body(%arg0_val: tensor<8xf32>):
    %res = affine.for %i = 0 to 10 iter_args(%sum_iter = %c2) -> tensor<8xf32> {
      %sum = arith.mulf %sum_iter, %arg0_val : tensor<8xf32>
      affine.yield %sum : tensor<8xf32>
    }
    secret.yield %res : tensor<8xf32>
  } -> !secret.secret<tensor<8xf32>>
  return %0 : !secret.secret<tensor<8xf32>>
}
```

The loop type invariance transforms convert this to the following before the mod-reduce insertion step:
```mlir
module attributes {backend.lattigo, scheme.ckks} {
  func.func @loop_mul(%arg0: !secret.secret<tensor<8xf32>>) -> !secret.secret<tensor<8xf32>> {
    %cst = arith.constant dense<2.000000e+00> : tensor<8xf32>
    %0 = secret.generic(%arg0: !secret.secret<tensor<8xf32>>) {
    ^body(%input0: tensor<8xf32>):
      %1 = arith.mulf %cst, %input0 : tensor<8xf32>
      %2 = mgmt.level_reduce_min %1 : tensor<8xf32>
      %3 = affine.for %arg1 = 1 to 10 iter_args(%arg2 = %2) -> (tensor<8xf32>) {
        %4 = mgmt.bootstrap %arg2 : tensor<8xf32>
        %5 = arith.mulf %4, %input0 : tensor<8xf32>
        %6 = mgmt.level_reduce_min %5 : tensor<8xf32>
        affine.yield %6 : tensor<8xf32>
      }
      secret.yield %3 : tensor<8xf32>
    } -> !secret.secret<tensor<8xf32>>
    return %0 : !secret.secret<tensor<8xf32>>
  }
}
```

Note that here the mulf always operates at multiplicative depth 0, so it is always "the first" mul op. Under the "before-mul (excluding the first mul)" mod-reduce insertion strategy, this implies that NO mod-reduce ops are inserted into the loop body. That in turn makes the number of levels consumed by the loop zero, which prevents loop unrolling (since the unroll factor is computed by dividing the level budget by the number of levels consumed in one iteration).
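The unrolling arithmetic described above can be sketched in a few lines. This is a hypothetical model of the heuristic, not HEIR's actual implementation; the function name and parameters are made up for illustration:

```python
def unroll_factor(level_budget: int, levels_per_iter: int) -> int:
    """Toy model of the unrolling heuristic: the unroll factor is how
    many loop iterations fit within the remaining level budget."""
    if levels_per_iter == 0:
        # No mod-reduce ops in the body means zero levels consumed per
        # iteration; the division is undefined, so unrolling is skipped.
        return 1
    return max(1, level_budget // levels_per_iter)

# With "before-mul (excluding the first mul)", the loop body above
# consumes 0 levels, so no unrolling happens:
print(unroll_factor(10, 0))  # -> 1
# With "before mul" (including the first mul), each iteration consumes
# one level, and the loop can be unrolled:
print(unroll_factor(10, 1))  # -> 10
```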
While there are probably workarounds for this acute problem, it also makes me question why we have the "before mul not including first mul" mod-reduce insertion strategy in the first place. @ZenithalHourlyRate could probably share some wisdom, but to my understanding it's because you can avoid the first mod-reduce by encrypting a fresh ciphertext to the modded-down level. So my questions are:
- If that reason is accurate, does it also apply to bootstrap? Or just fresh encryptions?
- If this optimization is important to the overall performance of a program (I suspect it may not be), can we achieve the same result by using "before mul" (including the first mul) and then running an optimization pass later in the pipeline that identifies when a freshly-encrypted ciphertext can be encrypted at a modded-down level (or a bootstrap can target a slightly lower level), replacing the mod_reduce op with a corresponding change to the IR (e.g., a func arg attr for the client helper, or a target_level one lower on the bootstrap op)?
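To make the second question concrete, here is a toy sketch of the proposed cleanup pass over a made-up micro-IR. None of these names (`Op`, `fold_mod_reduce`, the `target_level` attribute spelling) are HEIR APIs; they just illustrate the pattern of folding a mod-reduce into its producer:

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    """Minimal stand-in for an IR operation (hypothetical, not MLIR)."""
    name: str
    attrs: dict = field(default_factory=dict)
    operand: "Op | None" = None


def fold_mod_reduce(op: Op) -> Op:
    """If a mod_reduce's operand is a fresh encryption or a bootstrap,
    fold it into the producer by lowering the producer's target level
    by one, and drop the now-redundant mod_reduce."""
    if op.name == "mod_reduce" and op.operand is not None:
        producer = op.operand
        if producer.name in ("encrypt", "bootstrap"):
            producer.attrs["target_level"] = (
                producer.attrs.get("target_level", 0) - 1
            )
            return producer
    return op  # anything else is left untouched


# A bootstrap to level 5 followed by a mod_reduce becomes a single
# bootstrap to level 4:
boot = Op("bootstrap", {"target_level": 5})
folded = fold_mod_reduce(Op("mod_reduce", operand=boot))
print(folded.name, folded.attrs["target_level"])  # bootstrap 4
```

The same rewrite on an `encrypt` op would correspond to the func-arg-attr variant for the client helper mentioned above.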
In the meantime, I'm planning to punt on this, and the "bad outcome" is that a loop with a single mul in the body may not be properly unrolled. I think this will not affect our immediate deadlines, since we're focusing on CKKS and the strategy for CKKS is after-mul.