Skip to content

Commit 3631e0d

Browse files
committed
Add guards to prevent lowering eval-form polynomial constants
1 parent 34cf0bf commit 3631e0d

3 files changed

Lines changed: 99 additions & 0 deletions

File tree

lib/Dialect/Polynomial/Conversions/PolynomialToModArith/PolynomialToModArith.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,11 @@ struct ConvertConstant : public OpConversionPattern<ConstantOp> {
298298
op, "failed to construct common conversion info");
299299

300300
auto typeInfo = res.value();
301+
// TODO(#97): support compile-time NTT
302+
if (typeInfo.polynomialType.getForm() == Form::EVAL) {
303+
return rewriter.notifyMatchFailure(
304+
op, "unsupported eval-form polynomial constant");
305+
}
301306

302307
auto attr = dyn_cast<TypedIntPolynomialAttr>(op.getValue());
303308
if (!attr)
@@ -393,6 +398,45 @@ struct ConvertMonomial : public OpConversionPattern<MonomialOp> {
393398
auto storageTensorType =
394399
RankedTensorType::get(storageShape, typeInfo.coefficientStorageType);
395400

401+
// TODO(#97): support compile-time NTT
402+
// We don't have proper support for EVAL-form constants, but we can
403+
// at least support degree-zero polynomial constants in EVAL form. The
404+
// NTT of a degree-zero polynomial is a vector where each coefficient is the
405+
// constant term.
406+
if (typeInfo.polynomialType.getForm() == Form::EVAL) {
407+
auto degree = adaptor.getDegree().getDefiningOp<arith::ConstantIndexOp>();
408+
if (!degree || degree.value() != 0) {
409+
return rewriter.notifyMatchFailure(
410+
op, "unsupported eval-form non-constant monomial");
411+
}
412+
413+
Value result;
414+
if (auto modQTy = dyn_cast<ModQTypeInterface>(typeInfo.coefficientType)) {
415+
Type extractedType = modQTy.getLoweringType();
416+
Value extracted = mod_arith::ExtractOp::create(
417+
b, extractedType, adaptor.getCoefficient());
418+
Value replicatedStorage;
419+
if (isa<ShapedType>(extractedType)) {
420+
auto init = tensor::EmptyOp::create(b, storageTensorType.getShape(),
421+
typeInfo.coefficientStorageType);
422+
replicatedStorage = linalg::BroadcastOp::create(b, extracted, init,
423+
ArrayRef<int64_t>{0})
424+
.getResult()[0];
425+
} else {
426+
replicatedStorage =
427+
tensor::SplatOp::create(b, extracted, storageTensorType);
428+
}
429+
result = mod_arith::EncapsulateOp::create(b, typeInfo.tensorType,
430+
replicatedStorage)
431+
.getResult();
432+
} else {
433+
result = tensor::SplatOp::create(b, adaptor.getCoefficient(),
434+
typeInfo.tensorType);
435+
}
436+
rewriter.replaceOp(op, result);
437+
return success();
438+
}
439+
396440
auto tensor = arith::ConstantOp::create(
397441
b, DenseElementsAttr::get(
398442
storageTensorType,
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: heir-opt --polynomial-to-mod-arith --verify-diagnostics --split-input-file %s
2+
3+
#poly = #polynomial.int_polynomial<1 + x**4>
4+
#ring = #polynomial.ring<coefficientType=i32, polynomialModulus=#poly>
5+
!poly_ty = !polynomial.polynomial<ring=#ring, form=eval>
6+
7+
func.func @eval_constant() -> !poly_ty {
8+
// expected-error@+1 {{failed to legalize operation}}
9+
%0 = polynomial.constant int<1> : !poly_ty
10+
return %0 : !poly_ty
11+
}
12+
13+
// -----
14+
15+
#poly = #polynomial.int_polynomial<1 + x**4>
16+
#ring = #polynomial.ring<coefficientType=i32, polynomialModulus=#poly>
17+
!poly_ty = !polynomial.polynomial<ring=#ring, form=eval>
18+
19+
func.func @eval_nonconstant_monomial(%coeff: i32) -> !poly_ty {
20+
%c1 = arith.constant 1 : index
21+
// expected-error@+1 {{failed to legalize operation}}
22+
%0 = polynomial.monomial %coeff, %c1 : (i32, index) -> !poly_ty
23+
return %0 : !poly_ty
24+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: heir-opt --polynomial-to-mod-arith --mlir-print-local-scope %s | FileCheck %s
2+
3+
#poly = #polynomial.int_polynomial<1 + x**4>
4+
#ring_i32 = #polynomial.ring<coefficientType=i32, polynomialModulus=#poly>
5+
!poly_i32 = !polynomial.polynomial<ring=#ring_i32, form=eval>
6+
7+
// CHECK: func @eval_constant_monomial_int
8+
// CHECK-SAME: (%[[ARG0:.*]]: i32)
9+
func.func @eval_constant_monomial_int(%coeff: i32) -> !poly_i32 {
10+
%c0 = arith.constant 0 : index
11+
// CHECK: %[[SPLAT:.*]] = tensor.splat %[[ARG0]] : tensor<4xi32>
12+
// CHECK: return %[[SPLAT]]
13+
%0 = polynomial.monomial %coeff, %c0 : (i32, index) -> !poly_i32
14+
return %0 : !poly_i32
15+
}
16+
17+
!coeff_ty = !mod_arith.int<17 : i32>
18+
#ring_mod = #polynomial.ring<coefficientType=!coeff_ty, polynomialModulus=#poly>
19+
!poly_mod = !polynomial.polynomial<ring=#ring_mod, form=eval>
20+
21+
// CHECK: func @eval_constant_monomial_mod
22+
// CHECK-SAME: (%[[ARG0:.*]]: !mod_arith.int<17 : i32>)
23+
func.func @eval_constant_monomial_mod(%coeff: !coeff_ty) -> !poly_mod {
24+
%c0 = arith.constant 0 : index
25+
// CHECK: %[[EXTRACTED:.*]] = mod_arith.extract %[[ARG0]] : !mod_arith.int<17 : i32> -> i32
26+
// CHECK: %[[SPLAT:.*]] = tensor.splat %[[EXTRACTED]] : tensor<4xi32>
27+
// CHECK: %[[ENCAPSULATED:.*]] = mod_arith.encapsulate %[[SPLAT]] : tensor<4xi32> -> tensor<4x!mod_arith.int<17 : i32>>
28+
// CHECK: return %[[ENCAPSULATED]]
29+
%0 = polynomial.monomial %coeff, %c0 : (!coeff_ty, index) -> !poly_mod
30+
return %0 : !poly_mod
31+
}

0 commit comments

Comments
 (0)