Skip to content

HeterogeneousCore/SonicTriton: add RetryActionDiffServer; expose connectToServer; update tests #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions HeterogeneousCore/SonicCore/BuildFile.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
<use name="FWCore/Concurrency"/>
<use name="FWCore/MessageLogger"/>
<use name="FWCore/ParameterSet"/>
<use name="FWCore/PluginManager"/>
<use name="FWCore/Utilities"/>
<export>
<lib name="1"/>
Expand Down
35 changes: 35 additions & 0 deletions HeterogeneousCore/SonicCore/interface/RetryActionBase.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#ifndef HeterogeneousCore_SonicCore_RetryActionBase
#define HeterogeneousCore_SonicCore_RetryActionBase

#include "FWCore/PluginManager/interface/PluginFactory.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h"
#include <memory>
#include <string>

// Abstract interface for one retry strategy attached to a SonicClientBase.
// Concrete strategies are instantiated by name through RetryActionFactory.
class RetryActionBase {
public:
  // conf:   configuration of this action (not read by the base class itself)
  // client: non-owning back-pointer to the client whose call may be re-evaluated
  RetryActionBase(const edm::ParameterSet& conf, SonicClientBase* client);
  virtual ~RetryActionBase() = default;

  // Strategy-specific handling of one failed call.
  virtual void retry() = 0;
  // Strategy-specific (re)initialization; SonicClientBase::start() calls this on every action.
  virtual void start() = 0;

  // True while this action still wants to be consulted after a failed call.
  bool shouldRetry() const { return shouldRetry_; }

protected:
  // Helper for derived classes: forwards to the client's evaluate().
  void eval();

  SonicClientBase* client_;  // client to re-evaluate; never deleted by the action
  bool shouldRetry_;         // cleared by a derived action once it has no retries left
};

// Plugin factory that builds a retry action from its ParameterSet and the owning client.
using RetryActionFactory =
    edmplugin::PluginFactory<RetryActionBase*(const edm::ParameterSet&, SonicClientBase* client)>;

// Registers `type` with RetryActionFactory under its own class name.
// Kept inside the include guard so the macro is defined exactly once per translation unit,
// together with the factory alias it depends on.
#define DEFINE_RETRY_ACTION(type) DEFINE_EDM_PLUGIN(RetryActionFactory, type, #type);

#endif
14 changes: 13 additions & 1 deletion HeterogeneousCore/SonicCore/interface/SonicClientBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
#include "HeterogeneousCore/SonicCore/interface/SonicDispatcherPseudoAsync.h"

#include <string>
#include <vector>
#include <exception>
#include <memory>
#include <optional>

enum class SonicMode { Sync = 1, Async = 2, PseudoAsync = 3 };

class RetryActionBase;

class SonicClientBase {
public:
//constructor
Expand Down Expand Up @@ -54,14 +57,23 @@ class SonicClientBase {
SonicMode mode_;
bool verbose_;
std::unique_ptr<SonicDispatcher> dispatcher_;
unsigned allowedTries_, tries_;
unsigned totalTries_;
std::optional<edm::WaitingTaskWithArenaHolder> holder_;

// Use a unique_ptr with a custom deleter to avoid incomplete type issues
struct RetryDeleter {
void operator()(RetryActionBase* ptr) const;
};

using RetryActionPtr = std::unique_ptr<RetryActionBase, RetryDeleter>;
std::vector<RetryActionPtr> retryActions_;

//for logging/debugging
std::string debugName_, clientName_, fullDebugName_;

friend class SonicDispatcher;
friend class SonicDispatcherPseudoAsync;
friend class RetryActionBase;
};

#endif
6 changes: 6 additions & 0 deletions HeterogeneousCore/SonicCore/plugins/BuildFile.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
<use name="FWCore/Framework"/>
<use name="FWCore/PluginManager"/>
<use name="FWCore/ParameterSet"/>
<use name="HeterogeneousCore/SonicCore"/>
<library file="*.cc" name="pluginHeterogeneousCoreSonicCore"/>

30 changes: 30 additions & 0 deletions HeterogeneousCore/SonicCore/plugins/RetrySameServerAction.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h"
#include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h"

// Retry strategy that re-issues a failed call through the same client, up to a
// configured total number of tries (untracked parameter "allowedTries", default 0).
//
// Invariant: shouldRetry() is true only while a call to retry() would actually
// re-evaluate. SonicClientBase::finish() returns right after calling retry() on the
// first action whose shouldRetry() is true, so a retry() that silently declined
// would leave the failure unreported.
class RetrySameServerAction : public RetryActionBase {
public:
  RetrySameServerAction(const edm::ParameterSet& pset, SonicClientBase* client)
      : RetryActionBase(pset, client),
        allowedTries_(pset.getUntrackedParameter<unsigned>("allowedTries", 0)),
        tries_(0) {
    shouldRetry_ = hasTriesLeft();
  }

  // Reset the per-call state: the try counter AND the retry flag, so that an action
  // exhausted during one call is usable again for the next one.
  void start() override {
    tries_ = 0;
    shouldRetry_ = hasTriesLeft();
  }

protected:
  void retry() override;

private:
  // The failure that triggers retry() counts as one try, so another evaluation is
  // allowed only if tries_ + 1 stays below allowedTries_.
  bool hasTriesLeft() const { return tries_ + 1 < allowedTries_; }

  unsigned allowedTries_, tries_;
};

// Record one failed try and, if the budget allows, evaluate the client again.
void RetrySameServerAction::retry() {
  ++tries_;
  // Defensive: only reachable if the caller ignored shouldRetry().
  if (tries_ >= allowedTries_) {
    shouldRetry_ = false;  // Flip flag when max retries are reached
    edm::LogInfo("RetrySameServerAction") << "Max retry attempts reached. No further retries.";
    return;
  }
  // Update the flag BEFORE evaluating: eval() may re-enter finish() (and thus
  // shouldRetry()) before it returns, and that nested call must see the up-to-date state.
  shouldRetry_ = hasTriesLeft();
  eval();
}

DEFINE_RETRY_ACTION(RetrySameServerAction)
15 changes: 15 additions & 0 deletions HeterogeneousCore/SonicCore/src/RetryActionBase.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h"

// Constructor implementation
RetryActionBase::RetryActionBase(const edm::ParameterSet& conf, SonicClientBase* client)
: client_(client), shouldRetry_(true) {}

// Re-run the client's evaluate(); logs an error and does nothing when no client is attached.
void RetryActionBase::eval() {
  if (!client_) {
    edm::LogError("RetryActionBase") << "Client pointer is null, cannot evaluate.";
    return;
  }
  client_->evaluate();
}

EDM_REGISTER_PLUGINFACTORY(RetryActionFactory, "RetryActionFactory");
72 changes: 53 additions & 19 deletions HeterogeneousCore/SonicCore/src/SonicClientBase.cc
Original file line number Diff line number Diff line change
@@ -1,18 +1,33 @@
#include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h"
#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/ParameterSet/interface/allowedValues.h"

// Custom deleter implementation.
// Defined in this translation unit, which includes RetryActionBase.h, so that the
// delete expression sees the complete RetryActionBase type.
void SonicClientBase::RetryDeleter::operator()(RetryActionBase* ptr) const { delete ptr; }

SonicClientBase::SonicClientBase(const edm::ParameterSet& params,
const std::string& debugName,
const std::string& clientName)
: allowedTries_(params.getUntrackedParameter<unsigned>("allowedTries", 0)),
debugName_(debugName),
clientName_(clientName),
fullDebugName_(debugName_) {
: debugName_(debugName), clientName_(clientName), fullDebugName_(debugName_) {
if (!clientName_.empty())
fullDebugName_ += ":" + clientName_;

const auto& retryPSetList = params.getParameter<std::vector<edm::ParameterSet>>("Retry");
std::string modeName(params.getParameter<std::string>("mode"));

for (const auto& retryPSet : retryPSetList) {
const std::string& actionType = retryPSet.getParameter<std::string>("retryType");

auto retryAction = RetryActionFactory::get()->create(actionType, retryPSet, this);
if (retryAction) {
//Convert to RetryActionPtr Type from raw pointer of retryAction
retryActions_.emplace_back(RetryActionPtr(retryAction.release()));
} else {
throw cms::Exception("Configuration") << "Unknown Retry type " << actionType << " for SonicClient: " << modeName;
}
}

if (modeName == "Sync")
setMode(SonicMode::Sync);
else if (modeName == "Async")
Expand Down Expand Up @@ -40,24 +55,30 @@ void SonicClientBase::start(edm::WaitingTaskWithArenaHolder holder) {
holder_ = std::move(holder);
}

void SonicClientBase::start() { tries_ = 0; }
// Reset the retry bookkeeping: zero the failed-call counter and let every
// configured retry action reinitialize itself.
void SonicClientBase::start() {
// totalTries_ is incremented in finish() each time a call fails without an exception
totalTries_ = 0;
// give each retry action the chance to reset its own state via its start() override
for (const auto& action : retryActions_) {
action->start();
}
}

void SonicClientBase::finish(bool success, std::exception_ptr eptr) {
//retries are only allowed if no exception was raised
if (!success and !eptr) {
++tries_;
//if max retries has not been exceeded, call evaluate again
if (tries_ < allowedTries_) {
evaluate();
//avoid calling doneWaiting() twice
return;
}
//prepare an exception if exceeded
else {
edm::Exception ex(edm::errors::ExternalFailure);
ex << "SonicCallFailed: call failed after max " << tries_ << " tries";
eptr = make_exception_ptr(ex);
++totalTries_;
for (const auto& action : retryActions_) {
if (action->shouldRetry()) {
action->retry(); // Call retry only if shouldRetry_ is true
return;
}
}
//prepare an exception if no more retries left
edm::LogInfo("SonicClientBase") << "SonicCallFailed: call failed, no retry actions available after " << totalTries_
<< " tries.";
edm::Exception ex(edm::errors::ExternalFailure);
ex << "SonicCallFailed: call failed, no retry actions available after " << totalTries_ << " tries.";
eptr = make_exception_ptr(ex);
}
if (holder_) {
holder_->doneWaiting(eptr);
Expand All @@ -74,7 +95,20 @@ void SonicClientBase::fillBasePSetDescription(edm::ParameterSetDescription& desc
//restrict allowed values
desc.ifValue(edm::ParameterDescription<std::string>("mode", "PseudoAsync", true),
edm::allowedValues<std::string>("Sync", "Async", "PseudoAsync"));
if (allowRetry)
desc.addUntracked<unsigned>("allowedTries", 0);
if (allowRetry) {
// Defines the structure of each entry in the VPSet
edm::ParameterSetDescription retryDesc;
retryDesc.add<std::string>("retryType", "RetrySameServerAction");
retryDesc.addUntracked<unsigned>("allowedTries", 0);

// Define a default retry action
edm::ParameterSet defaultRetry;
defaultRetry.addParameter<std::string>("retryType", "RetrySameServerAction");
defaultRetry.addUntrackedParameter<unsigned>("allowedTries", 0);

// Add the VPSet with the default retry action
desc.addVPSet("Retry", retryDesc, {defaultRetry});
}
desc.add("sonicClientBase", desc);
desc.addUntracked<bool>("verbose", false);
}
2 changes: 1 addition & 1 deletion HeterogeneousCore/SonicCore/test/DummyClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class DummyClient : public SonicClient<int> {
this->output_ = this->input_ * factor_;

//simulate a failure
if (this->tries_ < fails_)
if (this->totalTries_ < fails_)
this->finish(false);
else
this->finish(true);
Expand Down
45 changes: 38 additions & 7 deletions HeterogeneousCore/SonicCore/test/sonicTest_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,19 @@

process.options.numberOfThreads = 2
process.options.numberOfStreams = 0

process.dummySync = _moduleClass(_moduleName,
input = cms.int32(1),
Client = cms.PSet(
mode = cms.string("Sync"),
factor = cms.int32(-1),
wait = cms.int32(10),
allowedTries = cms.untracked.uint32(0),
fails = cms.uint32(0),
Retry = cms.VPSet(
cms.PSet(
retryType = cms.string('RetrySameServerAction'),
allowedTries = cms.untracked.uint32(0)
)
)
),
)

Expand All @@ -37,8 +41,14 @@
mode = cms.string("PseudoAsync"),
factor = cms.int32(2),
wait = cms.int32(10),
allowedTries = cms.untracked.uint32(0),
fails = cms.uint32(0),
Retry = cms.VPSet(
cms.PSet(
retryType = cms.string('RetrySameServerAction'),
allowedTries = cms.untracked.uint32(0)
)
)

),
)

Expand All @@ -48,32 +58,53 @@
mode = cms.string("Async"),
factor = cms.int32(5),
wait = cms.int32(10),
allowedTries = cms.untracked.uint32(0),
fails = cms.uint32(0),
Retry = cms.VPSet(
cms.PSet(
retryType = cms.string('RetrySameServerAction'),
allowedTries = cms.untracked.uint32(0)
)
)
),
)

process.dummySyncRetry = process.dummySync.clone(
Client = dict(
wait = 2,
allowedTries = 2,
fails = 1,
Retry = cms.VPSet(
cms.PSet(
retryType = cms.string('RetrySameServerAction'),
allowedTries = cms.untracked.uint32(2)
)
)

)
)

process.dummyPseudoAsyncRetry = process.dummyPseudoAsync.clone(
Client = dict(
wait = 2,
allowedTries = 2,
fails = 1,
Retry = cms.VPSet(
cms.PSet(
retryType = cms.string('RetrySameServerAction'),
allowedTries = cms.untracked.uint32(2)
)
)
)
)

process.dummyAsyncRetry = process.dummyAsync.clone(
Client = dict(
wait = 2,
allowedTries = 2,
fails = 1,
Retry = cms.VPSet(
cms.PSet(
allowedTries = cms.untracked.uint32(2),
retryType = cms.string('RetrySameServerAction')
)
)
)
)

Expand Down
6 changes: 5 additions & 1 deletion HeterogeneousCore/SonicTriton/BuildFile.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
<use name="HeterogeneousCore/CUDAUtilities"/>
<use name="triton-inference-client"/>
<use name="protobuf"/>
<use name="catch2"/>

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be moved to test/BuildFile.xml, as indicated in the other comments (i.e. removed from here).

<iftool name="cuda">
<use name="cuda"/>
</iftool>

<export>
<lib name="1"/>
<lib name="1"/>
</export>

<test name="RetryActionDiffServer_test" command="RetryActionDiffServer.cc"/>

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tests should be defined in HeterogeneousCore/SonicTriton/test/BuildFile.xml, not the package-level BuildFile.xml.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and the correct syntax is:

<bin file="test_RetryActionDiffServer.cc" name="TestHeterogeneousCoreSonicTritonRetryActionDiffServer">
  <use name="catch2"/>
  <use name="FWCore/ParameterSet"/>
  <use name="HeterogeneousCore/SonicTriton"/>
</bin>

33 changes: 33 additions & 0 deletions HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#ifndef HeterogeneousCore_SonicTriton_RetryActionDiffServer_h
#define HeterogeneousCore_SonicTriton_RetryActionDiffServer_h

#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h"

/**
* @class RetryActionDiffServer
* @brief A concrete implementation of RetryActionBase that attempts to retry an inference
* request on a different, user-specified Triton server.
*
* This class is designed to provide a fallback mechanism. If an initial inference
* request fails (e.g., due to server unavailability or a model-specific error),
* this action will be triggered. It reads an alternative server URL from the
* ParameterSet and instructs the TritonClient to reconnect to this new server
* for the retry attempt. This action is designed for one-time use per inference
* call; after the retry attempt, it disables itself until the next `start()` call.
*/

class RetryActionDiffServer : public RetryActionBase {
public:
  RetryActionDiffServer(const edm::ParameterSet& conf, SonicClientBase* client);
  ~RetryActionDiffServer() override = default;

  // Re-arms the action for a new inference call.
  void start() override;
  // Performs the retry on the alternative server.
  void retry() override;

private:
  // URL of the alternative server to fall back to.
  std::string alt_server_url_;
  // NOTE(review): presumably a credential/token for the alternative server; confirm against the .cc file.
  std::string alt_server_token_;
};

#endif

Loading