 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include <stdint.h>
+
 #include <exception>
+
 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
 #include "triton/backend/backend_input_collector.h"
@@ -103,6 +105,7 @@ class ModelState : public BackendModel {
 
   bool EnabledWeightSharing() { return enable_weight_sharing_; }
   const std::vector<std::string>& ModelOutputs() { return output_names_; }
+  const std::string& MethodToCall() { return method_to_call_; }
 
  private:
   ModelState(TRITONBACKEND_Model* triton_model);
@@ -145,6 +148,10 @@ class ModelState : public BackendModel {
   // List of all the outputs specified in the output section of model
   // configuration.
   std::vector<std::string> output_names_;
+
+  // Method to call on PyTorch Module.
+  // Defaults to "forward".
+  std::string method_to_call_;
 };
 
 TRITONSERVER_Error*
@@ -180,7 +187,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
       enable_jit_executor_pair_({false, true}),
-      enable_nvfuser_pair_({false, false})
+      enable_nvfuser_pair_({false, false}), method_to_call_("forward")
 {
   output_names_.clear();
 
@@ -454,6 +461,30 @@ ModelState::ParseParameters()
            " for model instance '" + Name() + "'")
               .c_str());
     }
+
+    // If 'METHOD_TO_CALL' is not present in 'parameters' then no
+    // update is made to 'method_to_call_'.
+    std::string method_to_call = "forward";
+    err = GetParameterValue(params, "METHOD_TO_CALL", &method_to_call);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        LOG_MESSAGE(
+            TRITONSERVER_LOG_INFO,
+            (std::string("method_to_call is not specified") +
+             " for model instance '" + Name() + "'")
+                .c_str());
+        TRITONSERVER_ErrorDelete(err);
+      }
+    } else {
+      method_to_call_ = method_to_call;
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_INFO,
+          (std::string("method_to_call is ") + method_to_call_ +
+           " for model instance '" + Name() + "'")
+              .c_str());
+    }
   }
 
   return nullptr;
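Usage note (assumption, not part of this diff): following Triton's usual optional-parameter convention, a model would presumably opt in through its config.pbtxt with an entry such as `parameters: { key: "METHOD_TO_CALL" value: { string_value: "predict" } }`, where "predict" is a hypothetical method name exported by the TorchScript module. When the parameter is absent, GetParameterValue returns an error with code TRITONSERVER_ERROR_NOT_FOUND and the backend keeps the "forward" default initialized in the ModelState constructor above.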
@@ -764,7 +795,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   // configuration specifies only those.
   std::vector<std::string> allowed_inputs;
 
-  const torch::jit::Method& method = torch_model_->get_method("forward");
+  const torch::jit::Method& method =
+      torch_model_->get_method(model_state_->MethodToCall());
   const auto& schema = method.function().getSchema();
   const std::vector<c10::Argument>& arguments = schema.arguments();
 
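The schema inspection above works for any exported method, not only forward. A minimal standalone sketch of the same libtorch calls, assuming a hypothetical model.pt that exports a method named "predict" (illustrative only, not backend code):

#include <torch/script.h>

#include <iostream>

int main()
{
  // Load a serialized TorchScript module (the path is hypothetical).
  torch::jit::Module module = torch::jit::load("model.pt");

  // Look up an arbitrary exported method and walk its schema, which is
  // what ValidateInputs() now does with model_state_->MethodToCall().
  const torch::jit::Method& method = module.get_method("predict");
  const auto& schema = method.function().getSchema();
  for (const c10::Argument& arg : schema.arguments()) {
    // The first argument is the implicit 'self'; the remaining names are
    // what the inputs in the model configuration must match.
    std::cout << arg.name() << std::endl;
  }
  return 0;
}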
@@ -1312,28 +1344,32 @@ ModelInstanceState::Execute(
       torch::jit::overrideCanFuseOnCPU(false);
       torch::jit::overrideCanFuseOnGPU(false);
       torch::jit::setTensorExprFuserEnabled(false);
-        torch::jit::fuser::cuda::setEnabled(true);
+      torch::jit::fuser::cuda::setEnabled(true);
     } else {
       torch::jit::overrideCanFuseOnCPU(true);
       torch::jit::overrideCanFuseOnGPU(true);
       torch::jit::setTensorExprFuserEnabled(true);
-        torch::jit::fuser::cuda::setEnabled(false);
+      torch::jit::fuser::cuda::setEnabled(false);
     }
   }
 
   torch::NoGradGuard no_grad;
 
   // If input is a dictionary, prepare dictionary from 'input_tensors'.
+  std::string method_to_call = model_state_->MethodToCall();
   if (is_dict_input_) {
-    torch::Dict<std::string, torch::Tensor> input_dict;
+    c10::Dict<std::string, at::Tensor> dict;
     for (auto& input_index : input_index_map_) {
       torch::jit::IValue ival = (*input_tensors)[input_index.second];
-      input_dict.insert(input_index.first, ival.toTensor());
+      dict.insert(input_index.first, ival.toTensor());
     }
-    std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
-    model_outputs_ = torch_model_->forward(input_dict_ivalue);
+    model_outputs_ = torch_model_->run_method(method_to_call, dict);
   } else {
-    model_outputs_ = torch_model_->forward(*input_tensors);
+    auto inp = c10::impl::GenericList(c10::TensorType::get());
+    for (auto& input_tensor : *input_tensors) {
+      inp.emplace_back(input_tensor.toTensor());
+    }
+    model_outputs_ = torch_model_->run_method(method_to_call, inp);
   }
 
   if (model_outputs_.isTuple()) {
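run_method(name, args...) resolves the named method on the module and invokes it, returning an IValue exactly as forward() does. A minimal self-contained sketch, assuming the torch::jit::compile helper from the libtorch C++ frontend; the method name "scale" and its body are hypothetical:

#include <torch/script.h>
#include <torch/torch.h>

#include <iostream>

int main()
{
  // Compile a tiny TorchScript snippet into an in-memory compilation unit.
  auto cu = torch::jit::compile(R"JIT(
def scale(x, factor: float):
    return x * factor
)JIT");

  // Call the method by name with positional arguments, mirroring
  // torch_model_->run_method(method_to_call, ...) in Execute().
  torch::jit::IValue out = cu->run_method("scale", torch::ones({2, 2}), 3.0);
  std::cout << out.toTensor() << std::endl;
  return 0;
}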
@@ -1761,9 +1797,9 @@ ModelInstanceState::SetInputTensors(
 
         batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
       }
-    }
-    else {
-      batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+    } else {
+      batchn_shape =
+          std::vector<int64_t>(input_shape, input_shape + input_dims_count);
       if (supports_batching_) {
         batchn_shape[0] = total_batch_size;
       }
@@ -1772,8 +1808,8 @@ ModelInstanceState::SetInputTensors(
     // The input must be in contiguous CPU/GPU memory.
     std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
     if (device_.is_cpu()) {
-      alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
-                          {TRITONSERVER_MEMORY_CPU, 0}};
+      alloc_perference = {
+          {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
     } else {
       alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
     }
@@ -1887,9 +1923,11 @@ ModelInstanceState::ReadOutputTensors(
 
       // Output tensors may not reside on the same device as model
       torch::Device tensor_device = output_flat.device();
-      const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
-                                                                     : TRITONSERVER_MEMORY_GPU;
-      const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+      const auto memory_type = (tensor_device.type() == torch::kCPU)
+                                   ? TRITONSERVER_MEMORY_CPU
+                                   : TRITONSERVER_MEMORY_GPU;
+      const auto memory_id =
+          (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
 
       // Batch output doesn't support string data type yet, as it is not trivial
       // to parse string output
@@ -1906,16 +1944,16 @@ ModelInstanceState::ReadOutputTensors(
         return TRITONSERVER_ErrorNew(
             TRITONSERVER_ERROR_INVALID_ARG,
             (std::string("output '") + name +
-                 "' is a scalar which is not supported.")
+              "' is a scalar which is not supported.")
                 .c_str());
       }
 
       responder.ProcessTensor(
-          name, output_dtype, batchn_shape, output_buffer,
-          memory_type, memory_id);
+          name, output_dtype, batchn_shape, output_buffer, memory_type,
+          memory_id);
     } else {
       responder.ProcessBatchOutput(
-           name, *batch_output, output_buffer, memory_type, memory_id);
+          name, *batch_output, output_buffer, memory_type, memory_id);
     }
   } else if (output_tensors[op_index].isList()) {
     // Custom handling for string/bytes tensor...