diff --git a/shardy/dialect/sdy/ir/utils.cc b/shardy/dialect/sdy/ir/utils.cc index b460bb62f..1591d5d1e 100644 --- a/shardy/dialect/sdy/ir/utils.cc +++ b/shardy/dialect/sdy/ir/utils.cc @@ -27,6 +27,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/ScopedPrinter.h" #include "llvm/Support/Threading.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -43,6 +44,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" +#include "shardy/common/logging.h" #include "shardy/dialect/sdy/ir/constants.h" #include "shardy/dialect/sdy/ir/dialect.h" @@ -318,13 +320,6 @@ Value getShardableValue(Value value) { return shardableRegionOp.getEdgeOwnerFromTarget(value); }) .Default([&](Operation* op) { - // We only fail if the value isn't scalar. Scalar block arguments, such - // as the arguments of a reduction function, don't have a shardable - // value. This is ok since they are scalars (rank 0) and therefore can't - // be sharded. - if (!isScalar(value)) { - unreachableFormatv("region op '{0}' not supported", op->getName()); - } return nullptr; }); } @@ -386,7 +381,9 @@ TensorShardingAttr getOrCreateSharding(Value value, StringRef meshName, void setSharding(Value value, TensorShardingAttr sharding) { value = getShardableValue(value); - assert(value && "value should exist if its sharding is updated"); + SDY_CHECK(value) + << "value should be shardable if its sharding is updated, got: " + << llvm::to_string(value); TypeSwitch(getOwningOp(value)) .Case([&](FuncOp funcOp) { funcOp.setArgAttr(cast(value).getArgNumber(),