Skip to content

Commit 7f08435

Browse files
subhamsoni-googlecopybara-github
authored andcommitted
Add worker ID to responses and configure gRPC channels for retries and load balancing.
PiperOrigin-RevId: 827391460
1 parent c0ca36b commit 7f08435

File tree

7 files changed

+99
-7
lines changed

7 files changed

+99
-7
lines changed

plugin/xprof/profile_plugin.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,6 @@ def _get_valid_hosts(
738738
for xplane_path in path.glob(file_pattern):
739739
host_name, _ = _parse_filename(xplane_path.name)
740740
if host_name:
741-
print('host_name: %s', host_name)
742741
all_xplane_files[host_name] = xplane_path
743742
except OSError as e:
744743
logger.warning('Cannot read asset directory: %s, OpError %s', run_dir, e)

plugin/xprof/protobuf/worker_service.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,6 @@ message WorkerProfileDataRequest {
4040
message WorkerProfileDataResponse {
4141
// The absolute path to the tool specific output.
4242
string output = 1;
43+
// A unique identifier for the worker that handled the request.
44+
string worker_id = 2;
4345
}

plugin/xprof/worker/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ cc_library(
1414
"@com_github_grpc_grpc//:grpc++",
1515
"@com_google_absl//absl/log",
1616
"@com_google_absl//absl/status:statusor",
17+
"@com_google_absl//absl/strings",
1718
"@org_xprof//plugin/xprof/protobuf:worker_service_cc_grpc_proto",
1819
"@org_xprof//xprof/convert:profile_processor_factory",
1920
"@org_xprof//xprof/convert:tool_options",
@@ -27,6 +28,7 @@ cc_library(
2728
deps = [
2829
":worker_service",
2930
"@com_github_grpc_grpc//:grpc++",
31+
"@com_github_grpc_grpc//:grpc_security_base",
3032
"@com_google_absl//absl/log",
3133
"@com_google_absl//absl/strings",
3234
],
@@ -48,8 +50,10 @@ cc_library(
4850
hdrs = ["stub_factory.h"],
4951
deps = [
5052
"@com_github_grpc_grpc//:grpc++",
53+
"@com_github_grpc_grpc//:grpc_security_base",
5154
"@com_google_absl//absl/base:core_headers",
5255
"@com_google_absl//absl/base:no_destructor",
56+
"@com_google_absl//absl/log",
5357
"@com_google_absl//absl/strings",
5458
"@com_google_absl//absl/synchronization",
5559
"@org_xprof//plugin/xprof/protobuf:worker_service_cc_grpc_proto",

plugin/xprof/worker/grpc_server.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121

2222
#include "absl/log/log.h"
2323
#include "absl/strings/str_cat.h"
24+
#include "grpc/grpc.h"
2425
#include "grpcpp/security/server_credentials.h"
2526
#include "grpcpp/server.h"
2627
#include "grpcpp/server_builder.h"
@@ -39,6 +40,12 @@ void InitializeGrpcServer(int port) {
3940
std::string server_address = absl::StrCat(kServerAddressPrefix, port);
4041
::grpc::ServerBuilder builder;
4142
builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials());
43+
builder.AddChannelArgument(GRPC_ARG_KEEPALIVE_TIME_MS, 20000);
44+
builder.AddChannelArgument(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 10000);
45+
builder.AddChannelArgument(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0);
46+
builder.AddChannelArgument(GRPC_ARG_HTTP2_MAX_PING_STRIKES, 0);
47+
builder.AddChannelArgument(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, -1);
48+
builder.AddChannelArgument(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, -1);
4249
worker_service =
4350
std::make_unique<::xprof::profiler::ProfileWorkerServiceImpl>();
4451
builder.RegisterService(worker_service.get());

plugin/xprof/worker/stub_factory.cc

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,14 @@ limitations under the License.
2525
#include "absl/base/const_init.h"
2626
#include "absl/base/no_destructor.h"
2727
#include "absl/base/thread_annotations.h"
28+
#include "absl/log/log.h"
2829
#include "absl/strings/str_split.h"
2930
#include "absl/synchronization/mutex.h"
31+
#include "grpc/grpc.h"
3032
#include "grpcpp/channel.h"
3133
#include "grpcpp/create_channel.h"
3234
#include "grpcpp/security/credentials.h"
35+
#include "grpcpp/support/channel_arguments.h"
3336
#include "plugin/xprof/protobuf/worker_service.grpc.pb.h"
3437

3538
namespace xprof {
@@ -56,6 +59,59 @@ static absl::NoDestructor<
5659
static std::atomic<size_t> gCurrentStubIndex = 0;
5760
static std::atomic<bool> gStubsInitialized = false;
5861

62+
// Creates a gRPC channel for a given worker address. This channel is
63+
// configured with a service config that enables a robust retry policy for
64+
// transient errors and sets the client-side load balancing policy to
65+
// round-robin.
66+
std::shared_ptr<::grpc::Channel> CreateWorkerChannelForAddress(
67+
const std::string& address) {
68+
grpc::ChannelArguments args;
69+
// Set a service config for the channel that enables retries.
70+
// This config will be applied to all methods of the service.
71+
// Service Config: 10-minute timeout + conservative retries + LB
72+
const char* kServiceConfigJson = R"pb(
73+
{
74+
"methodConfig":
75+
[ {
76+
"name":
77+
[ {}],
78+
"timeout": "600s",
79+
"retryPolicy": {
80+
"maxAttempts": 4,
81+
"initialBackoff": "2s",
82+
"maxBackoff": "120s",
83+
"backoffMultiplier": 2.0,
84+
"retryableStatusCodes": [
85+
"UNAVAILABLE",
86+
"RESOURCE_EXHAUSTED",
87+
"INTERNAL",
88+
"ABORTED",
89+
"NOT_FOUND"
90+
]
91+
}
92+
}],
93+
"loadBalancingConfig":
94+
[ { "round_robin": {} }]
95+
})pb";
96+
args.SetServiceConfigJSON(kServiceConfigJson);
97+
args.SetLoadBalancingPolicyName("round_robin");
98+
args.SetInt(GRPC_ARG_DNS_MIN_TIME_BETWEEN_RESOLUTIONS_MS, 5000);
99+
args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, 20000);
100+
args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 10000);
101+
args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1);
102+
args.SetInt(GRPC_ARG_ENABLE_RETRIES, 1);
103+
args.SetInt(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, -1);
104+
args.SetInt(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, -1);
105+
106+
// Create the channel with insecure credentials. This is acceptable because
107+
// the communication between the aggregator and workers happens within a
108+
// trusted, internal network environment.
109+
std::shared_ptr<::grpc::Channel> channel = ::grpc::CreateCustomChannel(
110+
address, ::grpc::InsecureChannelCredentials(), args); // NOLINT
111+
LOG(INFO) << "Created gRPC channel for address: " << address;
112+
return channel;
113+
}
114+
59115
void InitializeStubs(const std::string& worker_service_addresses) {
60116
absl::MutexLock lock(&gStubsMutex);
61117
if (gStubsInitialized.load(std::memory_order_acquire)) {
@@ -66,8 +122,8 @@ void InitializeStubs(const std::string& worker_service_addresses) {
66122
absl::StrSplit(worker_service_addresses, kAddressDelimiter);
67123
for (const std::string& address : addresses) {
68124
if (address.empty()) continue;
69-
std::shared_ptr<::grpc::Channel> channel = ::grpc::CreateChannel(
70-
address, ::grpc::InsecureChannelCredentials()); // NOLINT
125+
std::shared_ptr<::grpc::Channel> channel =
126+
CreateWorkerChannelForAddress(address);
71127
gStubs->push_back(XprofAnalysisWorkerService::NewStub(channel));
72128
}
73129
gStubsInitialized.store(true, std::memory_order_release);

plugin/xprof/worker/worker_service.cc

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,22 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "plugin/xprof/worker/worker_service.h"
17+
#include "xprof/convert/tool_options.h"
18+
19+
#if defined(_WIN32)
20+
#include <winsock2.h>
21+
#else
22+
#include <unistd.h>
23+
#endif
1724

1825
#include <string>
1926

2027
#include "absl/log/log.h"
2128
#include "absl/status/statusor.h"
29+
#include "absl/strings/str_join.h"
2230
#include "grpcpp/server_context.h"
2331
#include "grpcpp/support/status.h"
2432
#include "xprof/convert/profile_processor_factory.h"
25-
#include "xprof/convert/tool_options.h"
2633
#include "plugin/xprof/worker/grpc_utils.h"
2734

2835
namespace xprof {
@@ -32,9 +39,12 @@ ::grpc::Status ProfileWorkerServiceImpl::GetProfileData(
3239
::grpc::ServerContext* context,
3340
const ::xprof::pywrap::WorkerProfileDataRequest* request,
3441
::xprof::pywrap::WorkerProfileDataResponse* response) {
35-
LOG(INFO) << "ProfileWorkerServiceImpl::GetProfileData called with request: "
36-
<< request->DebugString();
3742
const auto& origin_request = request->origin_request();
43+
LOG(INFO) << "GetProfileData tool:" << origin_request.tool_name()
44+
<< " session:" << origin_request.session_id() << " params:{"
45+
<< absl::StrJoin(origin_request.parameters(), ",",
46+
absl::PairFormatter("="))
47+
<< "}";
3848
tensorflow::profiler::ToolOptions tool_options;
3949
for (const auto& [key, value] : origin_request.parameters()) {
4050
tool_options[key] = value;
@@ -52,8 +62,19 @@ ::grpc::Status ProfileWorkerServiceImpl::GetProfileData(
5262
return ToGrpcStatus(map_output_file.status());
5363
}
5464
response->set_output(*map_output_file);
65+
66+
// POSIX standards limit hostnames to 255 characters, so a 1024 byte
67+
// buffer is generally safe.
68+
char hostname[1024];
69+
if (gethostname(hostname, 1024) == 0) {
70+
response->set_worker_id(hostname);
71+
} else {
72+
response->set_worker_id("unknown_hostname");
73+
}
5574
LOG(INFO)
56-
<< "ProfileWorkerServiceImpl::GetProfileData finished successfully.";
75+
<< "ProfileWorkerServiceImpl::GetProfileData finished successfully by "
76+
"worker: "
77+
<< response->worker_id();
5778
return ::grpc::Status::OK;
5879
}
5980

xprof/convert/xplane_to_tools_data.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,9 @@ absl::StatusOr<std::string> CallWorkerService(const std::string& xspace_path,
465465
if (!grpc_status.ok()) {
466466
return ::xprof::profiler::ToAbslStatus(grpc_status);
467467
}
468+
LOG(INFO) << "gRPC response: tool=" << tool_name
469+
<< ", session=" << xspace_path
470+
<< ", worker_id=" << response.worker_id();
468471
return response.output();
469472
}
470473

0 commit comments

Comments
 (0)