
Commit dbecc04

Support calling custom method names via METHOD_TO_CALL (fixes triton-inference-server/server#5209)
1 parent 588c6ac commit dbecc04

3 files changed, 74 insertions(+), 22 deletions(-)

README.md (+14)
@@ -206,6 +206,20 @@ complex execution modes and dynamic shapes. If not specified, all are enabled by
 
 `ENABLE_TENSOR_FUSER`
 
+* `METHOD_TO_CALL`: String flag to specify which method on the PyTorch model is being called.
+Default value is `forward`.
+
+The section of the model config file specifying this parameter will look like:
+
+```
+parameters: {
+key: "METHOD_TO_CALL"
+    value: {
+    string_value: "forward"
+    }
+}
+```
+
 ### Important Note
 
 * The execution of PyTorch model on GPU is asynchronous in nature. See
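
For reference, the behavior this option enables can be reproduced with libtorch alone. The sketch below is not part of this commit; `model.pt`, `custom_infer`, and the input shape are placeholder assumptions. The named method must exist on the saved TorchScript module (for scripted models this typically means annotating it with `@torch.jit.export`).

```
// Minimal standalone libtorch sketch (assumptions: "model.pt" exists and
// exports a one-tensor method named "custom_infer").
#include <torch/script.h>

#include <iostream>

int main()
{
  // Load the TorchScript module, as the backend does at model load time.
  torch::jit::Module module = torch::jit::load("model.pt");

  // Invoke the configured method by name instead of the default forward().
  c10::IValue output = module.run_method("custom_infer", torch::ones({1, 3}));
  std::cout << output.toTensor() << std::endl;
  return 0;
}
```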

src/libtorch.cc (+59, -21)
@@ -25,7 +25,9 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include <stdint.h>
+
 #include <exception>
+
 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
 #include "triton/backend/backend_input_collector.h"
@@ -103,6 +105,7 @@ class ModelState : public BackendModel {
 
   bool EnabledWeightSharing() { return enable_weight_sharing_; }
   const std::vector<std::string>& ModelOutputs() { return output_names_; }
+  const std::string& MethodToCall() { return method_to_call_; }
 
  private:
   ModelState(TRITONBACKEND_Model* triton_model);
@@ -145,6 +148,10 @@ class ModelState : public BackendModel {
   // List of all the outputs specified in the output section of model
   // configuration.
   std::vector<std::string> output_names_;
+
+  // Method to call on PyTorch Module.
+  // Defaults to "forward".
+  std::string method_to_call_;
 };
 
 TRITONSERVER_Error*
@@ -180,7 +187,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
     enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
     enable_jit_profiling_pair_({false, true}),
     enable_jit_executor_pair_({false, true}),
-    enable_nvfuser_pair_({false, false})
+    enable_nvfuser_pair_({false, false}), method_to_call_("forward")
 {
   output_names_.clear();
 
@@ -454,6 +461,30 @@ ModelState::ParseParameters()
             " for model instance '" + Name() + "'")
                .c_str());
     }
+
+    // If 'METHOD_TO_CALL' is not present in 'parameters' then no
+    // update is made to 'method_to_call_'.
+    std::string method_to_call = "forward";
+    err = GetParameterValue(params, "METHOD_TO_CALL", &method_to_call);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        LOG_MESSAGE(
+            TRITONSERVER_LOG_INFO,
+            (std::string("method_to_call is not specified") +
+             " for model instance '" + Name() + "'")
+                .c_str());
+        TRITONSERVER_ErrorDelete(err);
+      }
+    } else {
+      method_to_call_ = method_to_call;
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_INFO,
+          (std::string("method_to_call is ") + method_to_call_ +
+           " for model instance '" + Name() + "'")
+              .c_str());
+    }
   }
 
   return nullptr;
@@ -764,7 +795,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   // configuration specifies only those.
   std::vector<std::string> allowed_inputs;
 
-  const torch::jit::Method& method = torch_model_->get_method("forward");
+  const torch::jit::Method& method =
+      torch_model_->get_method(model_state_->MethodToCall());
   const auto& schema = method.function().getSchema();
   const std::vector<c10::Argument>& arguments = schema.arguments();
 
@@ -1312,28 +1344,32 @@ ModelInstanceState::Execute(
       torch::jit::overrideCanFuseOnCPU(false);
       torch::jit::overrideCanFuseOnGPU(false);
       torch::jit::setTensorExprFuserEnabled(false);
-       torch::jit::fuser::cuda::setEnabled(true);
+      torch::jit::fuser::cuda::setEnabled(true);
     } else {
       torch::jit::overrideCanFuseOnCPU(true);
       torch::jit::overrideCanFuseOnGPU(true);
       torch::jit::setTensorExprFuserEnabled(true);
-       torch::jit::fuser::cuda::setEnabled(false);
+      torch::jit::fuser::cuda::setEnabled(false);
     }
   }
 
   torch::NoGradGuard no_grad;
 
   // If input is a dictionary, prepare dictionary from 'input_tensors'.
+  std::string method_to_call = model_state_->MethodToCall();
   if (is_dict_input_) {
-    torch::Dict<std::string, torch::Tensor> input_dict;
+    c10::Dict<std::string, at::Tensor> dict;
     for (auto& input_index : input_index_map_) {
       torch::jit::IValue ival = (*input_tensors)[input_index.second];
-      input_dict.insert(input_index.first, ival.toTensor());
+      dict.insert(input_index.first, ival.toTensor());
     }
-    std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
-    model_outputs_ = torch_model_->forward(input_dict_ivalue);
+    model_outputs_ = torch_model_->run_method(method_to_call, dict);
   } else {
-    model_outputs_ = torch_model_->forward(*input_tensors);
+    auto inp = c10::impl::GenericList(c10::TensorType::get());
+    for (auto& input_tensor : *input_tensors) {
+      inp.emplace_back(input_tensor.toTensor());
+    }
+    model_outputs_ = torch_model_->run_method(method_to_call, inp);
   }
 
   if (model_outputs_.isTuple()) {
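
In the dictionary branch above, the whole `c10::Dict` is handed to `run_method` as a single positional argument. A minimal sketch of that calling convention outside the backend, assuming a placeholder module `model.pt` whose `custom_infer` method takes a `Dict[str, Tensor]`:

```
// Sketch of the dictionary calling convention used in Execute(); all names,
// input keys, and shapes here are assumptions for illustration.
#include <torch/script.h>

#include <iostream>

int main()
{
  torch::jit::Module module = torch::jit::load("model.pt");

  c10::Dict<std::string, at::Tensor> inputs;
  inputs.insert("INPUT0", torch::ones({1, 4}));
  inputs.insert("INPUT1", torch::zeros({1, 4}));

  // The dict is passed as one positional argument, matching the
  // is_dict_input_ branch.
  c10::IValue out = module.run_method("custom_infer", inputs);
  std::cout << out.toTensor() << std::endl;
  return 0;
}
```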
@@ -1761,9 +1797,9 @@ ModelInstanceState::SetInputTensors(
 
       batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
     }
-  }
-  else {
-    batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+  } else {
+    batchn_shape =
+        std::vector<int64_t>(input_shape, input_shape + input_dims_count);
     if (supports_batching_) {
       batchn_shape[0] = total_batch_size;
     }
@@ -1772,8 +1808,8 @@ ModelInstanceState::SetInputTensors(
   // The input must be in contiguous CPU/GPU memory.
   std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
   if (device_.is_cpu()) {
-    alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
-                        {TRITONSERVER_MEMORY_CPU, 0}};
+    alloc_perference = {
+        {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
   } else {
     alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
   }
@@ -1887,9 +1923,11 @@ ModelInstanceState::ReadOutputTensors(
 
       // Output tensors may not reside on the same device as model
       torch::Device tensor_device = output_flat.device();
-      const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
-                                                                     : TRITONSERVER_MEMORY_GPU;
-      const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+      const auto memory_type = (tensor_device.type() == torch::kCPU)
+                                   ? TRITONSERVER_MEMORY_CPU
+                                   : TRITONSERVER_MEMORY_GPU;
+      const auto memory_id =
+          (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
 
       // Batch output doesn't support string data type yet, as it is not trivial
       // to parse string output
@@ -1906,16 +1944,16 @@ ModelInstanceState::ReadOutputTensors(
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INVALID_ARG,
               (std::string("output '") + name +
-              "' is a scalar which is not supported.")
+               "' is a scalar which is not supported.")
                   .c_str());
         }
 
         responder.ProcessTensor(
-            name, output_dtype, batchn_shape, output_buffer,
-            memory_type, memory_id);
+            name, output_dtype, batchn_shape, output_buffer, memory_type,
+            memory_id);
       } else {
         responder.ProcessBatchOutput(
-           name, *batch_output, output_buffer, memory_type, memory_id);
+            name, *batch_output, output_buffer, memory_type, memory_id);
       }
     } else if (output_tensors[op_index].isList()) {
       // Custom handling for string/bytes tensor...

src/libtorch_utils.cc (+1, -1)
@@ -152,7 +152,7 @@ ParseParameter(
 #ifdef TRITON_ENABLE_GPU
 TRITONSERVER_Error*
 ConvertCUDAStatusToTritonError(
-    cudaError_t cuda_error,TRITONSERVER_Error_Code code, const char* msg)
+    cudaError_t cuda_error, TRITONSERVER_Error_Code code, const char* msg)
 {
   if (cuda_error != cudaSuccess) {
     return TRITONSERVER_ErrorNew(
