Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion python/tvm/relax/frontend/torch/exported_program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Te
-------
tvm.runtime.Tensor
The converted TVM tensor.

Raises
------
RuntimeError
If the tensor is a FakeTensor or other tensor subclass that cannot be converted.
"""
# PyTorch sparse tensors (layout != torch.strided) must be converted to dense.
if tensor_value.layout != torch.strided:
Expand Down Expand Up @@ -1688,11 +1693,27 @@ def from_exported_program(
binding = {}
for tensor_name, tensor_value in to_bind_parameters.items():
# find relax var name from graph signature
bind_name = None
for spec in exported_program.graph_signature.input_specs:
if tensor_name == spec.target:
bind_name = spec.arg.name
break
binding[bind_name] = self._convert_pytorch_tensor_to_tvm(tensor_value)
if bind_name is None:
# Skip tensors that don't have corresponding input specs
# (e.g., lifted_tensor from torch.export)
continue
try:
binding[bind_name] = self._convert_pytorch_tensor_to_tvm(tensor_value)
except RuntimeError as e:
# Skip FakeTensor/lifted tensors that cannot be converted
# These are typically intermediate tensors that torch.export couldn't properly lift
import warnings

warnings.warn(
f"Skipping parameter '{tensor_name}' (bind_name: '{bind_name}'): "
f"Cannot convert tensor to TVM format: {e}"
)
continue

mod = self.block_builder.get()
mod = relax.transform.BindParams("main", binding)(mod)
Expand Down
192 changes: 175 additions & 17 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@
#include <tvm/relax/expr.h>
#include <tvm/runtime/device_api.h>

#include <optional>
#include <queue>
#include <sstream>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace tvm {
namespace transform {
Expand Down Expand Up @@ -443,15 +449,6 @@ const SequentialNode* Sequential::operator->() const {
return static_cast<const SequentialNode*>(get());
}

void SequentialNode::ResolveDependency(const IRModule& mod) {
// TODO(zhiics) Implement it.
// 1. Consider the required passes for each pass.
// 2. Only resolve the enabled passes.
// 3. Build a dependency graph. Probably we need to update the pass list.
LOG(FATAL) << "Pass dependency has not been resolved yet."
<< "\n";
}

Pass GetPass(const ffi::String& pass_name) {
std::optional<tvm::ffi::Function> f;
if (pass_name.operator std::string().find("transform.") != std::string::npos) {
Expand All @@ -463,10 +460,174 @@ Pass GetPass(const ffi::String& pass_name) {
return (*f)().cast<Pass>();
}

// TODO(zhiics): we currently only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future.
// Safe version of GetPass that returns empty optional instead of throwing
std::optional<Pass> TryGetPass(const ffi::String& pass_name) {
std::optional<tvm::ffi::Function> f;
if (pass_name.operator std::string().find("transform.") != std::string::npos) {
f = tvm::ffi::Function::GetGlobal(pass_name);
} else {
f = tvm::ffi::Function::GetGlobal("transform." + pass_name);
}
if (!f.has_value()) {
return std::nullopt;
}
return (*f)().cast<Pass>();
}

void SequentialNode::ResolveDependency(const IRModule& mod) {
// Get the current pass context to check which passes are enabled
// Note: mod parameter is reserved for future use when dependency resolution
// might need to consider module-specific information
(void)mod; // Suppress unused parameter warning
PassContext pass_ctx = PassContext::Current();

// Step 1: Collect all enabled passes from the current list
std::unordered_map<std::string, Pass> name_to_pass;
std::vector<Pass> enabled_passes;

for (const Pass& pass : passes) {
if (!pass.defined()) {
continue;
}
const PassInfo& pass_info = pass->Info();
if (pass_ctx.PassEnabled(pass_info)) {
std::string pass_name = pass_info->name;
// Avoid duplicates
if (name_to_pass.find(pass_name) == name_to_pass.end()) {
name_to_pass[pass_name] = pass;
enabled_passes.push_back(pass);
}
}
}

// Step 2: Collect all required passes that are not in the current list
// We need to do this in multiple passes to handle transitive dependencies
std::unordered_set<std::string> processed_required;
bool changed = true;
while (changed) {
changed = false;
for (size_t i = 0; i < enabled_passes.size(); ++i) {
const PassInfo& pass_info = enabled_passes[i]->Info();
for (const auto& required_name : pass_info->required) {
std::string req_name = required_name;
std::string key = pass_info->name + "->" + req_name;
if (processed_required.find(key) != processed_required.end()) {
continue;
}
processed_required.insert(key);

// Check if the required pass is already in our list
if (name_to_pass.find(req_name) == name_to_pass.end()) {
// Try to get it from the global registry
// Use TryGetPass to avoid exceptions when the pass is not registered
std::optional<Pass> required_pass_opt = TryGetPass(ffi::String(req_name));
if (required_pass_opt.has_value()) {
Pass required_pass = required_pass_opt.value();
const PassInfo& req_pass_info = required_pass->Info();
if (pass_ctx.PassEnabled(req_pass_info)) {
name_to_pass[req_name] = required_pass;
enabled_passes.push_back(required_pass);
changed = true;
}
} else {
// If we can't get the pass from the registry, we'll skip this dependency
// This can happen if the required pass is not registered globally
// It will be resolved at runtime in operator() if needed
VLOG(0) << "Warning: Cannot resolve required pass '" << req_name << "' for pass '"
<< pass_info->name
<< "' from global registry. It will be resolved at runtime if needed.";
}
}
}
}
}

// Step 3: Build dependency graph
// Map from pass name to its index in enabled_passes
std::unordered_map<std::string, size_t> name_to_index;
for (size_t i = 0; i < enabled_passes.size(); ++i) {
const PassInfo& pass_info = enabled_passes[i]->Info();
name_to_index[pass_info->name] = i;
}

// Build reverse adjacency list: dependents[i] contains indices of passes that depend on pass i
// This is used for topological sort
std::vector<std::vector<size_t>> dependents(enabled_passes.size());
std::vector<size_t> in_degree(enabled_passes.size(), 0);

for (size_t i = 0; i < enabled_passes.size(); ++i) {
const PassInfo& pass_info = enabled_passes[i]->Info();
for (const auto& required_name : pass_info->required) {
std::string req_name = required_name;
auto it = name_to_index.find(req_name);
if (it != name_to_index.end()) {
// The required pass is in our enabled passes list
// pass i depends on pass req_idx, so req_idx should come before i
size_t req_idx = it->second;
dependents[req_idx].push_back(i);
in_degree[i]++;
}
// If the required pass is not in our list, it will be handled at runtime
}
}

// Step 4: Topological sort using Kahn's algorithm
std::queue<size_t> queue;
for (size_t i = 0; i < enabled_passes.size(); ++i) {
if (in_degree[i] == 0) {
queue.push(i);
}
}

std::vector<Pass> sorted_passes;
// Track which passes have been sorted to handle circular dependencies
std::vector<bool> sorted(enabled_passes.size(), false);

while (!queue.empty()) {
size_t current = queue.front();
queue.pop();

// In Kahn's algorithm, a node is added to queue only when in_degree becomes 0,
// which happens exactly once for each node in a DAG, so no need to check visited
sorted_passes.push_back(enabled_passes[current]);
sorted[current] = true;

// Process dependents: passes that depend on the current pass
for (size_t dependent : dependents[current]) {
in_degree[dependent]--;
if (in_degree[dependent] == 0) {
queue.push(dependent);
}
}
}

// Check for circular dependencies
if (sorted_passes.size() != enabled_passes.size()) {
std::ostringstream os;
os << "Circular dependency detected in pass sequence. "
<< "Only " << sorted_passes.size() << " out of " << enabled_passes.size()
<< " passes were sorted. Remaining passes will be appended in original order.";
LOG(WARNING) << os.str();
// Add remaining passes that weren't sorted (they have circular dependencies)
for (size_t i = 0; i < enabled_passes.size(); ++i) {
if (!sorted[i]) {
sorted_passes.push_back(enabled_passes[i]);
}
}
}

// Step 5: Update the passes list
passes = ffi::Array<Pass>(sorted_passes);
}

IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
// Resolve dependencies and sort passes using topological sort
// Note: We need to call ResolveDependency which modifies the passes member,
// but since SequentialNode is an Object (immutable reference), we can safely
// modify it here as the actual object data is mutable.
const_cast<SequentialNode*>(this)->ResolveDependency(mod);

// Execute passes in the resolved order
for (const Pass& pass : passes) {
VLOG(0) << "Running pass " << pass->Info()->name;
ICHECK(pass.defined()) << "Found undefined pass for optimization.";
Expand All @@ -476,11 +637,8 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c
continue;
}

// resolve dependencies
for (const auto& it : pass_info->required) {
mod = GetPass(it)(std::move(mod), pass_ctx);
}

// Dependencies are already resolved and sorted by ResolveDependency,
// so we just execute the pass directly
mod = pass(std::move(mod), pass_ctx);
}
return mod;
Expand Down
Loading