Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions shardy/dialect/sdy/ir/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/Threading.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand All @@ -43,6 +44,7 @@ limitations under the License.
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "shardy/common/logging.h"
#include "shardy/dialect/sdy/ir/constants.h"
#include "shardy/dialect/sdy/ir/dialect.h"

Expand Down Expand Up @@ -318,13 +320,6 @@ Value getShardableValue(Value value) {
return shardableRegionOp.getEdgeOwnerFromTarget(value);
})
.Default([&](Operation* op) {
// We only fail if the value isn't scalar. Scalar block arguments, such
// as the arguments of a reduction function, don't have a shardable
// value. This is ok since they are scalars (rank 0) and therefore can't
// be sharded.
if (!isScalar(value)) {
unreachableFormatv("region op '{0}' not supported", op->getName());
}
return nullptr;
});
}
Expand Down Expand Up @@ -386,7 +381,9 @@ TensorShardingAttr getOrCreateSharding(Value value, StringRef meshName,

void setSharding(Value value, TensorShardingAttr sharding) {
value = getShardableValue(value);
assert(value && "value should exist if its sharding is updated");
SDY_CHECK(value)
<< "value should be shardable if its sharding is updated, got: "
<< llvm::to_string(value);
TypeSwitch<Operation*>(getOwningOp(value))
.Case<FuncOp>([&](FuncOp funcOp) {
funcOp.setArgAttr(cast<BlockArgument>(value).getArgNumber(),
Expand Down
Loading