diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 985b1e1edbb..803183cab3b 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2897,50 +2897,74 @@ void InitXlaModuleBindings(py::module m) {
   // -------------Dynamo Integration API Start-------------------------
   /*
    * Return tensor ids and at::tensors for all DeviceData nodes that is needed
-   * to compute the value of tensors.
+   * to compute the value of tensors. In case the input tensors are provided,
+   * we ensure that the returning ID and IValue are retained, in order to avoid
+   * implicitly creating a new XLA Tensor (with a new unique tensor ID).
+   *
+   * output_tensors: tensors whose IR graphs are searched for device data.
+   * input_tensors: optional (empty list by default); a tensor from this list
+   * is returned as-is when its tensor ID matches one of the collected IDs.
    */
-  m.def("_get_tensors_xla_device_data_node",
-        [](const std::vector<at::Tensor>& tensors)
-            -> std::pair<std::vector<int64_t>, std::vector<at::IValue>> {
-          std::vector<int64_t> tensor_ids;
-          std::vector<at::IValue> ivalues;
-          std::vector<const torch::lazy::Node*> roots;
-          for (const at::Tensor& tensor : tensors) {
-            auto xtensor = bridge::TryGetXlaTensor(tensor);
-            if (xtensor) {
-              roots.push_back(xtensor->GetIrValue().node.get());
-            }
+  m.def(
+      "_get_tensors_xla_device_data_node",
+      [](const std::vector<at::Tensor>& output_tensors,
+         const std::vector<at::Tensor>& input_tensors)
+          -> std::pair<std::vector<int64_t>, std::vector<at::IValue>> {
+        std::vector<const torch::lazy::Node*> roots;
+        for (const at::Tensor& tensor : output_tensors) {
+          auto xtensor = bridge::TryGetXlaTensor(tensor);
+          if (xtensor) {
+            roots.push_back(xtensor->GetIrValue().node.get());
           }
-          auto post_order = torch::lazy::Util::ComputePostOrder(roots);
-          std::unordered_set<torch::lazy::BackendData::Handle> data_handles;
-
-          for (const torch::lazy::Node* nodeptr : post_order) {
-            const auto backend_data =
-                torch::lazy::getBackend()->GetComputationDataFromNode(nodeptr);
-            if (!backend_data) {
-              continue;
-            }
+        }
 
-            // Dedup by handle
-            torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
-            if (!data_handles.insert(handle).second) {
-              continue;
-            }
-            auto* infoptr =
-                static_cast<torch::lazy::LazyGraphExecutor::DeviceDataInfo*>(
-                    backend_data->info());
-            if (infoptr) {
-              tensor_ids.push_back(infoptr->tensor_id);
-            } else {
-              // TODO(JackCaoG): Make sure this device data is actually seed.
-              tensor_ids.push_back(seed_info_id);
-            }
+        // Index the provided inputs by tensor ID so they can be reused below.
+        std::unordered_map<int64_t, at::Tensor> input_tensor_map;
+        input_tensor_map.reserve(input_tensors.size());
+        for (const at::Tensor& tensor : input_tensors) {
+          int64_t tensor_id = GetTensorId(tensor);
+          input_tensor_map[tensor_id] = tensor;
+        }
+
+        auto post_order = torch::lazy::Util::ComputePostOrder(roots);
+        std::unordered_set<torch::lazy::BackendData::Handle> data_handles;
+
+        std::vector<int64_t> tensor_ids;
+        std::vector<at::IValue> ivalues;
+        for (const torch::lazy::Node* nodeptr : post_order) {
+          const auto backend_data =
+              torch::lazy::getBackend()->GetComputationDataFromNode(nodeptr);
+          if (!backend_data) {
+            continue;
+          }
+
+          // Dedup by handle
+          torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
+          if (!data_handles.insert(handle).second) {
+            continue;
+          }
+          auto* infoptr =
+              static_cast<torch::lazy::LazyGraphExecutor::DeviceDataInfo*>(
+                  backend_data->info());
+
+          // TODO(JackCaoG): Make sure this device data is actually seed.
+          int64_t tensor_id = infoptr ? infoptr->tensor_id : seed_info_id;
+          tensor_ids.push_back(tensor_id);
+          // Reuse the caller's tensor when its ID matches; the else-branch
+          // wraps the backend data in a new XLA tensor instead.
+          auto input_it = input_tensor_map.find(tensor_id);
+          if (input_it != input_tensor_map.end()) {
+            ivalues.emplace_back(input_it->second);
+          } else {
             at::Tensor tensor = bridge::AtenFromXlaTensor(
                 torch_xla::XLATensor::Create(backend_data));
             ivalues.emplace_back(tensor);
           }
-          return std::make_pair(tensor_ids, ivalues);
-        });
+        }
+        return std::make_pair(tensor_ids, ivalues);
+      },
+      // Keyword names are bound positionally: keep the lambda's order.
+      py::arg("output_tensors"), py::arg("input_tensors") = py::list());
 
   m.def("_get_seed_info_id", []() -> int64_t { return seed_info_id; });
 