From 352df5c5ac927c88b465a54e5424c5eecd55a21d Mon Sep 17 00:00:00 2001 From: Subham Soni Date: Mon, 3 Nov 2025 01:59:07 -0800 Subject: [PATCH] Add worker ID to responses and configure gRPC channels for retries and load balancing. PiperOrigin-RevId: 827391460 --- plugin/xprof/profile_plugin.py | 1 - plugin/xprof/protobuf/worker_service.proto | 2 + plugin/xprof/worker/BUILD | 5 ++ plugin/xprof/worker/grpc_server.cc | 7 +++ plugin/xprof/worker/grpc_utils.h | 2 + plugin/xprof/worker/stub_factory.cc | 67 ++++++++++++++++++++-- plugin/xprof/worker/worker_service.cc | 15 ++++- xprof/convert/xplane_to_tools_data.cc | 3 + 8 files changed, 94 insertions(+), 8 deletions(-) diff --git a/plugin/xprof/profile_plugin.py b/plugin/xprof/profile_plugin.py index 64edce57b..6f8f99416 100644 --- a/plugin/xprof/profile_plugin.py +++ b/plugin/xprof/profile_plugin.py @@ -738,7 +738,6 @@ def _get_valid_hosts( for xplane_path in path.glob(file_pattern): host_name, _ = _parse_filename(xplane_path.name) if host_name: - print('host_name: %s', host_name) all_xplane_files[host_name] = xplane_path except OSError as e: logger.warning('Cannot read asset directory: %s, OpError %s', run_dir, e) diff --git a/plugin/xprof/protobuf/worker_service.proto b/plugin/xprof/protobuf/worker_service.proto index cb4decab6..6f9c06fd1 100644 --- a/plugin/xprof/protobuf/worker_service.proto +++ b/plugin/xprof/protobuf/worker_service.proto @@ -40,4 +40,6 @@ message WorkerProfileDataRequest { message WorkerProfileDataResponse { // The absolute path to the tool specific output. string output = 1; + // A unique identifier for the worker that handled the request. 
+ string worker_id = 2; } diff --git a/plugin/xprof/worker/BUILD b/plugin/xprof/worker/BUILD index 456de9d69..714bdfc80 100644 --- a/plugin/xprof/worker/BUILD +++ b/plugin/xprof/worker/BUILD @@ -14,9 +14,11 @@ cc_library( "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@org_xprof//plugin/xprof/protobuf:worker_service_cc_grpc_proto", "@org_xprof//xprof/convert:profile_processor_factory", "@org_xprof//xprof/convert:tool_options", + "@tsl//tsl/platform:platform_port", ], ) @@ -27,6 +29,7 @@ cc_library( deps = [ ":worker_service", "@com_github_grpc_grpc//:grpc++", + "@com_github_grpc_grpc//:grpc_security_base", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", ], @@ -48,8 +51,10 @@ cc_library( hdrs = ["stub_factory.h"], deps = [ "@com_github_grpc_grpc//:grpc++", + "@com_github_grpc_grpc//:grpc_security_base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@org_xprof//plugin/xprof/protobuf:worker_service_cc_grpc_proto", diff --git a/plugin/xprof/worker/grpc_server.cc b/plugin/xprof/worker/grpc_server.cc index 76ea5bcf3..0345fcd8d 100644 --- a/plugin/xprof/worker/grpc_server.cc +++ b/plugin/xprof/worker/grpc_server.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "absl/log/log.h" #include "absl/strings/str_cat.h" +#include "grpc/grpc.h" #include "grpcpp/security/server_credentials.h" #include "grpcpp/server.h" #include "grpcpp/server_builder.h" @@ -39,6 +40,12 @@ void InitializeGrpcServer(int port) { std::string server_address = absl::StrCat(kServerAddressPrefix, port); ::grpc::ServerBuilder builder; builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials()); + builder.AddChannelArgument(GRPC_ARG_KEEPALIVE_TIME_MS, 20000); + builder.AddChannelArgument(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 10000); + builder.AddChannelArgument(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0); + builder.AddChannelArgument(GRPC_ARG_HTTP2_MAX_PING_STRIKES, 0); + builder.AddChannelArgument(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, -1); + builder.AddChannelArgument(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, -1); worker_service = std::make_unique<::xprof::profiler::ProfileWorkerServiceImpl>(); builder.RegisterService(worker_service.get()); diff --git a/plugin/xprof/worker/grpc_utils.h b/plugin/xprof/worker/grpc_utils.h index 2707d28d9..18551bc36 100644 --- a/plugin/xprof/worker/grpc_utils.h +++ b/plugin/xprof/worker/grpc_utils.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef THIRD_PARTY_XPROF_PLUGIN_TENSORBOARD_PLUGIN_PROFILE_WORKER_GRPC_UTILS_H_ #define THIRD_PARTY_XPROF_PLUGIN_TENSORBOARD_PLUGIN_PROFILE_WORKER_GRPC_UTILS_H_ +#include + #include "absl/status/status.h" #include "grpcpp/support/status.h" diff --git a/plugin/xprof/worker/stub_factory.cc b/plugin/xprof/worker/stub_factory.cc index 89611a9f9..e61a457af 100644 --- a/plugin/xprof/worker/stub_factory.cc +++ b/plugin/xprof/worker/stub_factory.cc @@ -25,11 +25,15 @@ limitations under the License. 
#include "absl/base/const_init.h" #include "absl/base/no_destructor.h" #include "absl/base/thread_annotations.h" +#include "absl/log/log.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "grpc/grpc.h" #include "grpcpp/channel.h" #include "grpcpp/create_channel.h" #include "grpcpp/security/credentials.h" +#include "grpcpp/support/channel_arguments.h" #include "plugin/xprof/protobuf/worker_service.grpc.pb.h" namespace xprof { @@ -39,6 +43,35 @@ using xprof::pywrap::grpc::XprofAnalysisWorkerService; constexpr char kAddressDelimiter = ','; +// Service config for the gRPC channel. This config will be applied to all +// methods of the service. It enables a robust retry policy for transient errors +// (UNAVAILABLE, RESOURCE_EXHAUSTED, etc.), sets a 10-minute timeout, and +// configures client-side round-robin load balancing. +constexpr char kServiceConfigJson[] = R"pb( + { + "methodConfig": + [ { + "name": + [ {}], + "timeout": "600s", + "retryPolicy": { + "maxAttempts": 4, + "initialBackoff": "2s", + "maxBackoff": "120s", + "backoffMultiplier": 2.0, + "retryableStatusCodes": [ + "UNAVAILABLE", + "RESOURCE_EXHAUSTED", + "INTERNAL", + "ABORTED", + "NOT_FOUND" + ] + } + }], + "loadBalancingConfig": + [ { "round_robin": {} }] + })pb"; + ABSL_CONST_INIT absl::Mutex gStubsMutex(absl::kConstInit); // gStubs holds the gRPC stubs for the worker services. // It is a vector of unique_ptrs to ensure that the stubs are properly @@ -56,18 +89,44 @@ static absl::NoDestructor< static std::atomic gCurrentStubIndex = 0; static std::atomic gStubsInitialized = false; +// Creates a gRPC channel for a given worker address. This channel is +// configured with a service config that enables a robust retry policy for +// transient errors and sets the client-side load balancing policy to +// round-robin. 
+static std::shared_ptr<::grpc::Channel> CreateWorkerChannelForAddress( + absl::string_view address) { + grpc::ChannelArguments args; + args.SetServiceConfigJSON(kServiceConfigJson); + args.SetLoadBalancingPolicyName("round_robin"); + args.SetInt(GRPC_ARG_DNS_MIN_TIME_BETWEEN_RESOLUTIONS_MS, 5000); + args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, 20000); + args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 10000); + args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1); + args.SetInt(GRPC_ARG_ENABLE_RETRIES, 1); + args.SetInt(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, -1); + args.SetInt(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, -1); + + // Create the channel with insecure credentials. This is acceptable because + // the communication between the aggregator and workers happens within a + // trusted, internal network environment. + std::shared_ptr<::grpc::Channel> channel = ::grpc::CreateCustomChannel( + std::string(address), ::grpc::InsecureChannelCredentials(), args); // NOLINT + LOG(INFO) << "Created gRPC channel for address: " << address; + return channel; +} + void InitializeStubs(const std::string& worker_service_addresses) { absl::MutexLock lock(&gStubsMutex); if (gStubsInitialized.load(std::memory_order_acquire)) { // Already initialized. 
return; } - std::vector<std::string> addresses = + std::vector<absl::string_view> addresses = absl::StrSplit(worker_service_addresses, kAddressDelimiter); - for (const std::string& address : addresses) { + for (absl::string_view address : addresses) { if (address.empty()) continue; - std::shared_ptr<::grpc::Channel> channel = ::grpc::CreateChannel( - address, ::grpc::InsecureChannelCredentials()); // NOLINT + std::shared_ptr<::grpc::Channel> channel = + CreateWorkerChannelForAddress(address); gStubs->push_back(XprofAnalysisWorkerService::NewStub(channel)); } gStubsInitialized.store(true, std::memory_order_release); diff --git a/plugin/xprof/worker/worker_service.cc b/plugin/xprof/worker/worker_service.cc index c1f589da7..8a50ca7ce 100644 --- a/plugin/xprof/worker/worker_service.cc +++ b/plugin/xprof/worker/worker_service.cc @@ -19,11 +19,13 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" +#include "absl/strings/str_join.h" #include "grpcpp/server_context.h" #include "grpcpp/support/status.h" #include "xprof/convert/profile_processor_factory.h" #include "xprof/convert/tool_options.h" #include "plugin/xprof/worker/grpc_utils.h" +#include "tsl/platform/host_info.h" namespace xprof { namespace profiler { @@ -32,9 +34,12 @@ ::grpc::Status ProfileWorkerServiceImpl::GetProfileData( ::grpc::ServerContext* context, const ::xprof::pywrap::WorkerProfileDataRequest* request, ::xprof::pywrap::WorkerProfileDataResponse* response) { - LOG(INFO) << "ProfileWorkerServiceImpl::GetProfileData called with request: " - << request->DebugString(); const auto& origin_request = request->origin_request(); + LOG(INFO) << "GetProfileData tool:" << origin_request.tool_name() + << " session:" << origin_request.session_id() << " params:{" + << absl::StrJoin(origin_request.parameters(), ",", + absl::PairFormatter("=")) + << "}"; tensorflow::profiler::ToolOptions tool_options; for (const auto& [key, value] : origin_request.parameters()) { tool_options[key] = value; @@ -52,8 +57,12 @@ 
::grpc::Status ProfileWorkerServiceImpl::GetProfileData( return ToGrpcStatus(map_output_file.status()); } response->set_output(*map_output_file); + response->set_worker_id(tsl::port::Hostname()); + LOG(INFO) - << "ProfileWorkerServiceImpl::GetProfileData finished successfully."; + << "ProfileWorkerServiceImpl::GetProfileData finished successfully by " + "worker: " + << response->worker_id(); return ::grpc::Status::OK; } diff --git a/xprof/convert/xplane_to_tools_data.cc b/xprof/convert/xplane_to_tools_data.cc index b165f91cf..e4264b5ac 100644 --- a/xprof/convert/xplane_to_tools_data.cc +++ b/xprof/convert/xplane_to_tools_data.cc @@ -465,6 +465,9 @@ absl::StatusOr<std::string> CallWorkerService(const std::string& xspace_path, if (!grpc_status.ok()) { return ::xprof::profiler::ToAbslStatus(grpc_status); } + LOG(INFO) << "gRPC response: tool=" << tool_name + << ", session=" << xspace_path + << ", worker_id=" << response.worker_id(); return response.output(); }