
Port optimizations from HEaaN.mlir paper #635

Open
j2kun opened this issue Apr 20, 2024 · 4 comments
Labels: optimization, research synthesis

j2kun (Collaborator) commented Apr 20, 2024

https://dl.acm.org/doi/pdf/10.1145/3591228

The lowest-hanging fruit for us seems to be the loop fusion passes, which we could apply to the polynomial dialect after ntt is lowered to affine in the mlir-polynomial-to-llvm pipeline.

inbelic (Contributor) commented May 30, 2024

In the paper they also describe an optimization of modulo arithmetic that removes the modulo operator (Rem[S|U]Op) from multiplication/addition of polynomials in the NTT domain. They introduce an operation representing the first Barrett reduction step, say arith_ext.barrett_reduce, and an operation arith_ext.subifge x y that denotes z = (x >= y) ? x - y : x.
For instance:

%mul = arith.muli %x, %y : i32
%res = arith.remui %mul, %cmod : i32

becomes

%mul = arith.muli %x, %y : i32
%barrett = arith_ext.barrett_reduce %mul, %cmod : i32
%res = arith_ext.subifge %barrett, %cmod : i32
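
For reference, a minimal sketch (my own, not from the paper) of how arith_ext.subifge itself could later lower into plain arith, as an unsigned compare plus select:

%ge = arith.cmpi uge, %x, %y : i32
%diff = arith.subi %x, %y : i32
%res = arith.select %ge, %diff, %x : i32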

This avoids the potential division in the remainder operation, leaving only runtime multiplications and bit shifts, since the Barrett ratio can be computed statically. We could use this optimization directly in the current NTT lowering.
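
To make the static computation concrete, here is a rough sketch (mine, not from the paper) of what arith_ext.barrett_reduce could expand to for an assumed 13-bit prime cmod = 7681, where ratio = floor(2^26 / 7681) = 8736 is a compile-time constant:

%cmod = arith.constant 7681 : i64
%ratio = arith.constant 8736 : i64  // floor(2^26 / cmod), computed statically
%shift = arith.constant 26 : i64
%q0 = arith.muli %mul, %ratio : i64 // x * ratio
%q = arith.shrui %q0, %shift : i64  // approximates floor(x / cmod)
%qc = arith.muli %q, %cmod : i64
%r = arith.subi %mul, %qc : i64     // x - q * cmod, lands in [0, 2*cmod)
// a trailing arith_ext.subifge %r, %cmod then folds %r into [0, cmod)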

Further, they describe a data-flow analysis that reduces the number of arith_ext.subifge operations when the results feed subsequent muls/adds/subs. These optimizations assume the input polynomial coefficients are in the range [0, cmod); however, this restriction is not currently enforced in the polynomial dialect (coefficients can be negative). This could be addressed with an operation poly.normalise that ensures all polynomial coefficients are in the range [0, cmod).

So I would propose the following steps to implement the paper's optimizations:

  • Create a new dialect arith_ext with the operations barrett_reduce and subifge.
  • Create a pass that converts arith.muli + arith.remui into arith.muli + arith_ext.barrett_reduce + arith_ext.subifge when the operands are in the range [0, cmod).
  • Create a pass that converts arith.addi + arith.remui into arith.addi + arith_ext.subifge when the operands are in the range [0, cmod).
  • Create a pass that converts arith.subi + arith.remui into arith.subi + arith.addi cmod + arith_ext.subifge when the operands are in the range [0, cmod). (The add and sub rewrites are sketched after this list.)
  • Add a new poly.normalise operation to poly to provide a fixed range [0, cmod) for use in the above passes and for modelling the ranges in the data-flow analysis.
  • Implement the paper's data-flow analysis to reduce the number of arith_ext.subifge operations.
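
As a concrete illustration, the add and sub rewrites might look like the following (my sketch; it assumes %a and %b are unsigned i32 values in [0, cmod) with cmod < 2^31, so no intermediate overflows):

// add: %a + %b lies in [0, 2*cmod), so one conditional subtraction suffices
%sum = arith.addi %a, %b : i32
%r0 = arith_ext.subifge %sum, %cmod : i32

// sub: %a - %b may wrap around, but adding cmod cancels the wraparound,
// since the true value a + cmod - b lies in (0, 2*cmod)
%dif = arith.subi %a, %b : i32
%adj = arith.addi %dif, %cmod : i32
%r1 = arith_ext.subifge %adj, %cmod : i32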

Looking for feedback on all aspects of the proposed solution, especially the operation names.

asraa (Collaborator) commented May 31, 2024

Nice! I'm excited to see the difference when applied to the NTT lowering :)

> as the Barrett ratio is able to be statically computed

Just to check my understanding: for computing the Barrett ratio, during poly-to-standard lowering the pass would statically compute the Barrett ratio and insert the computation for the Barrett reduction, correct? If so, that makes sense, and I like the arith_ext style dialect and name. Since the modulus for arith_ext.barrett_reduce must be constant and statically known, I wonder if it should be an attribute?

%barrett = arith_ext.barrett_reduce %mul {modulus = cmod}

> This could be addressed with an operation poly.normalise that will ensure all the polynomial coefficients are in the range [0, cmod).

Hmm, yes, that's a good point. I'm a little curious how just an operation will play out. Do we need an attribute on the polynomial type itself to mark that it is normalized? Without it, I would think that a polynomial.mul lowering to standard would need to do some analysis to determine whether its inputs were the result of a poly.normalise.

j2kun (Collaborator, Author) commented May 31, 2024

> require the assumption that the input polynomial coefficients are in the range [0, cmod)

Another possibility I was considering while writing #675 is that we should ensure this invariant always holds. I didn't ultimately do it in that PR because I found some confusing behavior around remsi/remui: either BOTH operands are treated as signed or BOTH as unsigned, which gives the wrong result either way if you have (-1 : i32) % cmod.
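
To spell out the pitfall (my illustration, using an assumed cmod = 7681):

%neg1 = arith.constant -1 : i32
%cmod = arith.constant 7681 : i32
// remui treats -1 as 2^32 - 1, so this yields (2^32 - 1) mod 7681 = 5568
%u = arith.remui %neg1, %cmod : i32
// remsi keeps the sign, so this yields -1, which is outside [0, cmod)
%s = arith.remsi %neg1, %cmod : i32
// the canonical representative of -1 mod 7681 is 7680; neither op produces it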

But we could consider that. I think @AlexanderViand-Intel should chime in since this would have to be compatible with polynomial ISA considerations.

Otherwise I think this is a great plan.

inbelic (Contributor) commented May 31, 2024

Yes exactly, we can compute the ratio from half the operand bit-width and cmod. I agree with making it an attribute since it is static.

Hm, good point. Thinking out loud: the optimizations would happen after polynomial-to-standard, where we would apply the first set of passes to introduce the arith_ext operations, followed by the data-flow analysis. We would then need a way to denote that the values of the tensor are normalized after lowering from the poly level.

We could use an encoding on the tensor type to denote that it is normalised w.r.t. some cmod. Then when we lower to_tensor we can mark those tensors: either all of them, if the invariant holds globally, or just the polynomials that have the attribute. Although I think we would still need some sort of analysis to propagate the ranges once we operate at the arith/arith_ext level.

Another option would be to introduce the poly.normalise op, which would lower to an arith_ext.normalise { modulus = cmod }, allowing the range to be propagated at the arith/arith_ext level.
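
Roughly, the two options might look like this (both hypothetical; the attribute and op names are placeholders):

// (a) an encoding on the lowered tensor type, attached when lowering to_tensor
%t = ... : tensor<1024xi32, #arith_ext.normalised<modulus = 7681>>
// (b) an explicit op carrying the modulus, for a range analysis to pick up
%n = arith_ext.normalise %t { modulus = 7681 : i32 } : tensor<1024xi32>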

I can start now with adding the arith_ext dialect and its operations/passes. Then we have some time to think about how to go about the rest.
