@@ -298,6 +298,11 @@ struct ConvertConstant : public OpConversionPattern<ConstantOp> {
298298 op, " failed to construct common conversion info" );
299299
300300 auto typeInfo = res.value ();
301+ // TODO(#97): support compile-time NTT
302+ if (typeInfo.polynomialType .getForm () == Form::EVAL ) {
303+ return rewriter.notifyMatchFailure (
304+ op, " unsupported eval-form polynomial constant" );
305+ }
301306
302307 auto attr = dyn_cast<TypedIntPolynomialAttr>(op.getValue ());
303308 if (!attr)
@@ -393,6 +398,45 @@ struct ConvertMonomial : public OpConversionPattern<MonomialOp> {
393398 auto storageTensorType =
394399 RankedTensorType::get (storageShape, typeInfo.coefficientStorageType );
395400
401+ // TODO(#97): support compile-time NTT
402+ // We don't have proper support for EVAL-form constants, but we can
403+ // at least support degree-zero polynomial constants in EVAL form. The
404+ // NTT of a degree-zero polynomial is a vector where each coefficient is the
405+ // constant term.
406+ if (typeInfo.polynomialType .getForm () == Form::EVAL ) {
407+ auto degree = adaptor.getDegree ().getDefiningOp <arith::ConstantIndexOp>();
408+ if (!degree || degree.value () != 0 ) {
409+ return rewriter.notifyMatchFailure (
410+ op, " unsupported eval-form non-constant monomial" );
411+ }
412+
413+ Value result;
414+ if (auto modQTy = dyn_cast<ModQTypeInterface>(typeInfo.coefficientType )) {
415+ Type extractedType = modQTy.getLoweringType ();
416+ Value extracted = mod_arith::ExtractOp::create (
417+ b, extractedType, adaptor.getCoefficient ());
418+ Value replicatedStorage;
419+ if (isa<ShapedType>(extractedType)) {
420+ auto init = tensor::EmptyOp::create (b, storageTensorType.getShape (),
421+ typeInfo.coefficientStorageType );
422+ replicatedStorage = linalg::BroadcastOp::create (b, extracted, init,
423+ ArrayRef<int64_t >{0 })
424+ .getResult ()[0 ];
425+ } else {
426+ replicatedStorage =
427+ tensor::SplatOp::create (b, extracted, storageTensorType);
428+ }
429+ result = mod_arith::EncapsulateOp::create (b, typeInfo.tensorType ,
430+ replicatedStorage)
431+ .getResult ();
432+ } else {
433+ result = tensor::SplatOp::create (b, adaptor.getCoefficient (),
434+ typeInfo.tensorType );
435+ }
436+ rewriter.replaceOp (op, result);
437+ return success ();
438+ }
439+
396440 auto tensor = arith::ConstantOp::create (
397441 b, DenseElementsAttr::get (
398442 storageTensorType,
0 commit comments