From ca5c46538ad48dce8028f524f7c5b7d4c4b1db9d Mon Sep 17 00:00:00 2001 From: "yuyongqiang.yyq" Date: Mon, 8 Sep 2025 03:15:47 +0000 Subject: [PATCH 1/4] add session negotiating before running --- MODULE.bazel.lock | 275 ++++++++++++++++++++ engine/framework/BUILD.bazel | 12 + engine/framework/session.cc | 133 ++++++++-- engine/framework/session.h | 33 +-- engine/framework/session_negotiation.proto | 38 +++ engine/operator/in.cc | 4 +- engine/operator/join.cc | 4 +- engine/services/engine_service_impl_test.cc | 102 ++++++++ 8 files changed, 568 insertions(+), 33 deletions(-) create mode 100644 engine/framework/session_negotiation.proto diff --git a/MODULE.bazel.lock b/MODULE.bazel.lock index 3dac170e..c800ce7d 100644 --- a/MODULE.bazel.lock +++ b/MODULE.bazel.lock @@ -886,6 +886,137 @@ }, "selectedYankedVersions": {}, "moduleExtensions": { + "//bazel:defs.bzl%non_module_dependencies": { + "general": { + "bzlTransitiveDigest": "klEmpjNC1WHT/ZVR9b9FjhZRsul/OS1ZYICoaOjJTZU=", + "usagesDigest": "vBXKVqtlSxvdHzob/mZXuLxycd10yCCC4x/AizZPBNU=", + "recordedFileInputs": {}, + "recordedDirentsInputs": {}, + "envVariables": {}, + "generatedRepoSpecs": { + "com_github_gperftools_gperftools": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "type": "tar.gz", + "strip_prefix": "gperftools-2.15", + "sha256": "c69fef855628c81ef56f12e3c58f2b7ce1f326c0a1fe783e5cae0b88cbbe9a80", + "urls": [ + "https://github.com/gperftools/gperftools/releases/download/gperftools-2.15/gperftools-2.15.tar.gz" + ], + "build_file": "@@//engine/bazel:gperftools.BUILD" + } + }, + "io_opentelemetry_cpp": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "urls": [ + "https://codeload.github.com/open-telemetry/opentelemetry-cpp/tar.gz/refs/tags/v1.3.0" + ], + "sha256": "6a4c43b9c9f753841ebc0fe2717325271f02e2a1d5ddd0b52735c35243629ab3", + "strip_prefix": "opentelemetry-cpp-1.3.0" + } + }, + "com_mysql": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "urls": [ + "https://github.com/mysql/mysql-server/archive/refs/tags/mysql-8.0.30.tar.gz" + ], + "patch_args": [ + "-p1" + ], + "patches": [ + "@@//engine/bazel:patches/mysql.patch" + ], + "sha256": "e76636197f9cb764940ad8d800644841771def046ce6ae75c346181d5cdd879a", + "strip_prefix": "mysql-server-mysql-8.0.30", + "build_file": "@@//engine/bazel:mysql.BUILD" + } + }, + "org_pocoproject_poco": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "urls": [ + "https://github.com/pocoproject/poco/archive/refs/tags/poco-1.12.2-release.tar.gz" + ], + "strip_prefix": "poco-poco-1.12.2-release", + "sha256": "30442ccb097a0074133f699213a59d6f8c77db5b2c98a7c1ad9c5eeb3a2b06f3", + "build_file": "@@//engine/bazel:poco.BUILD", + "patch_args": [ + "-p1" + ], + "patches": [ + "@@//engine/bazel:patches/poco.patch" + ] + } + }, + "org_sqlite": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "urls": [ + "https://www.sqlite.org/2020/sqlite-amalgamation-3320200.zip" + ], + "sha256": "7e1ebd182a61682f94b67df24c3e6563ace182126139315b659f25511e2d0b5d", + "strip_prefix": "sqlite-amalgamation-3320200", + "build_file": "@@//engine/bazel:sqlite3.BUILD" + } + }, + "com_github_duckdb": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "urls": [ + "https://github.com/duckdb/duckdb/archive/refs/tags/v1.0.0.tar.gz" + ], + "patch_args": [ + "-p1" + ], + "patches": [ + "@@//engine/bazel:patches/duckdb.patch" + ], + "sha256": "04e472e646f5cadd0a3f877a143610674b0d2bcf9f4102203ac3c3d02f1c5f26", + "strip_prefix": "duckdb-1.0.0", + "build_file": "@@//engine/bazel:duckdb.BUILD" + } + }, + "org_postgres": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "urls": [ + "https://ftp.postgresql.org/pub/source/v15.2/postgresql-15.2.tar.gz" + ], + "sha256": "eccd208f3e7412ad7bc4c648ecc87e0aa514e02c24a48f71bf9e46910bf284ca", + "strip_prefix": "postgresql-15.2", + "build_file": "@@//engine/bazel:postgres.BUILD" + } + }, + "rules_proto_grpc": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "sha256": "2a0860a336ae836b54671cbbe0710eec17c64ef70c4c5a88ccfd47ea6e3739bd", + "strip_prefix": "rules_proto_grpc-4.6.0", + "urls": [ + "https://github.com/rules-proto-grpc/rules_proto_grpc/releases/download/4.6.0/rules_proto_grpc-4.6.0.tar.gz" + ] + } + } + }, + "recordedRepoMappingEntries": [ + [ + "", + "bazel_tools", + "bazel_tools" + ] + ] + } + }, "@@apple_support~//crosstool:setup.bzl%apple_cc_configure_extension": { "general": { "bzlTransitiveDigest": "7ii+gFxWSxHhQPrBxfMEHhtrGvHmBTvsh+KOyGunP/s=", @@ -1416,6 +1547,106 @@ "recordedRepoMappingEntries": [] } }, + "@@grpc~//bazel:grpc_deps.bzl%grpc_repo_deps_ext": { + "general": { + "bzlTransitiveDigest": "5TfzYlp8Y+UKkVz5RlvJ+WoMjs6d3+Y3sHhn1c7I86o=", + "usagesDigest": "1eMIqi5dORV/U/qhnyfh+vm4MzWYt3L+ru89nRp4d/A=", + "recordedFileInputs": {}, + "recordedDirentsInputs": {}, + "envVariables": {}, + "generatedRepoSpecs": { + "io_opencensus_cpp": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "sha256": "46b3b5812c150a21bacf860c2f76fc42b89773ed77ee954c32adeb8593aa2a8e", + "strip_prefix": "opencensus-cpp-5501a1a255805e0be83a41348bb5f2630d5ed6b3", + "urls": [ + "https://storage.googleapis.com/grpc-bazel-mirror/github.com/census-instrumentation/opencensus-cpp/archive/5501a1a255805e0be83a41348bb5f2630d5ed6b3.tar.gz", + "https://github.com/census-instrumentation/opencensus-cpp/archive/5501a1a255805e0be83a41348bb5f2630d5ed6b3.tar.gz" + ] + } + }, + "envoy_api": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "sha256": "e525a6fb6e6ed3eef1eec6bef3da9b5708e471f0f9335a7604df14a4b386231e", + "strip_prefix": "data-plane-api-f8b75d1efa92bbf534596a013d9ca5873f79dd30", + "urls": [ + "https://storage.googleapis.com/grpc-bazel-mirror/github.com/envoyproxy/data-plane-api/archive/f8b75d1efa92bbf534596a013d9ca5873f79dd30.tar.gz", + "https://github.com/envoyproxy/data-plane-api/archive/f8b75d1efa92bbf534596a013d9ca5873f79dd30.tar.gz" + ] + } + }, + "opencensus_proto": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "sha256": "b7e13f0b4259e80c3070b583c2f39e53153085a6918718b1c710caf7037572b0", + "strip_prefix": "opencensus-proto-0.3.0/src", + "urls": [ + "https://storage.googleapis.com/grpc-bazel-mirror/github.com/census-instrumentation/opencensus-proto/archive/v0.3.0.tar.gz", + "https://github.com/census-instrumentation/opencensus-proto/archive/v0.3.0.tar.gz" + ] + } + }, + "com_envoyproxy_protoc_gen_validate": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "strip_prefix": "protoc-gen-validate-4694024279bdac52b77e22dc87808bd0fd732b69", + "sha256": "1e490b98005664d149b379a9529a6aa05932b8a11b76b4cd86f3d22d76346f47", + "urls": [ + "https://github.com/envoyproxy/protoc-gen-validate/archive/4694024279bdac52b77e22dc87808bd0fd732b69.tar.gz" + ], + "patches": [ + "@@grpc~//third_party:protoc-gen-validate.patch" + ], + "patch_args": [ + "-p1" + ] + } + }, + "com_github_cncf_xds": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "sha256": "dc305e20c9fa80822322271b50aa2ffa917bf4fd3973bcec52bfc28dc32c5927", + "strip_prefix": "xds-3a472e524827f72d1ad621c4983dd5af54c46776", + "urls": [ + "https://storage.googleapis.com/grpc-bazel-mirror/github.com/cncf/xds/archive/3a472e524827f72d1ad621c4983dd5af54c46776.tar.gz", + "https://github.com/cncf/xds/archive/3a472e524827f72d1ad621c4983dd5af54c46776.tar.gz" + ] + } + }, + "google_cloud_cpp": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "sha256": "7ca7f583b60d2aa1274411fed3b9fb3887119b2e84244bb3fc69ea1db819e4e5", + "strip_prefix": "google-cloud-cpp-2.16.0", + "urls": [ + "https://storage.googleapis.com/grpc-bazel-mirror/github.com/googleapis/google-cloud-cpp/archive/refs/tags/v2.16.0.tar.gz", + "https://github.com/googleapis/google-cloud-cpp/archive/refs/tags/v2.16.0.tar.gz" + ] + } + } + }, + "recordedRepoMappingEntries": [ + [ + "grpc~", + "bazel_tools", + "bazel_tools" + ], + [ + "grpc~", + "com_github_grpc_grpc", + "grpc~" + ] + ] + } + }, "@@rules_buf~//buf:extensions.bzl%ext": { "general": { "bzlTransitiveDigest": "gmPmM7QT5Jez2VVFcwbbMf/QWSRag+nJ1elFJFFTcn0=", @@ -5152,6 +5383,50 @@ ] ] } + }, + "@@spulib~//bazel:defs.bzl%non_module_dependencies": { + "general": { + "bzlTransitiveDigest": "JT8ZLEUdrYXN19gijrHtztFq/cEAhJlRlNjhtQUlDIE=", + "usagesDigest": "1GNetOfcHHCHzNnKy7W83BJODSj5fGSOhsu29kEblUk=", + "recordedFileInputs": {}, + "recordedDirentsInputs": {}, + "envVariables": {}, + "generatedRepoSpecs": { + "xtensor": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "sha256": "32d5d9fd23998c57e746c375a544edf544b74f0a18ad6bc3c38cbba968d5e6c7", + "strip_prefix": "xtensor-0.25.0", + "build_file": "@@spulib~//bazel:xtensor.BUILD", + "type": "tar.gz", + "urls": [ + "https://github.com/xtensor-stack/xtensor/archive/refs/tags/0.25.0.tar.gz" + ] + } + }, + "xtl": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "sha256": "44fb99fbf5e56af5c43619fc8c29aa58e5fad18f3ba6e7d9c55c111b62df1fbb", + "strip_prefix": "xtl-0.7.7", + "build_file": "@@spulib~//bazel:xtl.BUILD", + "type": "tar.gz", + "urls": [ + "https://github.com/xtensor-stack/xtl/archive/refs/tags/0.7.7.tar.gz" + ] + } + } + }, + "recordedRepoMappingEntries": [ + [ + "spulib~", + "bazel_tools", + "bazel_tools" + ] + ] + } } } } diff --git a/engine/framework/BUILD.bazel b/engine/framework/BUILD.bazel index 9c265e0b..d0cf8a1b 100644 --- a/engine/framework/BUILD.bazel +++ b/engine/framework/BUILD.bazel @@ -73,6 +73,7 @@ cc_library( hdrs = ["session.h"], deps = [ ":party_info", + ":session_negotiation_cc_proto", ":tensor_table", "//api:common_cc_proto", "//api:engine_cc_proto", @@ -141,3 +142,14 @@ scql_cc_test( "//engine/operator:test_util", ], ) + +proto_library( + name = "session_negotiation_proto", + srcs = ["session_negotiation.proto"], + deps = ["//api:engine_proto"], +) + +cc_proto_library( + name = "session_negotiation_cc_proto", + deps = [":session_negotiation_proto"], +) diff --git a/engine/framework/session.cc b/engine/framework/session.cc index 0c3e7eae..fabe3fa1 100644 --- a/engine/framework/session.cc +++ b/engine/framework/session.cc @@ -14,11 +14,15 @@ #include "engine/framework/session.h" +#include #include #include +#include +#include #include "algorithm" #include "arrow/visit_array_inline.h" +#include "google/protobuf/util/json_util.h" #include "libspu/core/config.h" #include "libspu/mpc/factory.h" #include "openssl/sha.h" @@ -32,6 +36,8 @@ #include "engine/util/prometheus_monitor.h" #include "engine/util/psi/detail_logger.h" +#include "engine/framework/session_negotiation.pb.h" + DEFINE_string(tmp_file_path, "/tmp", "dir to out tmp files"); DEFINE_uint64( streaming_row_num_threshold, 30 * 1000 * 1000, @@ -148,21 +154,115 @@ Session::Session(const SessionOptions& session_opt, util::PrometheusMonitor::GetInstance()->IncSessionNumberTotal(); // default not streaming - streaming_options_.batched = false; - streaming_options_.streaming_row_num_threshold = + session_opt_.streaming_options.batched = false; + session_opt_.streaming_options.streaming_row_num_threshold = FLAGS_streaming_row_num_threshold; - streaming_options_.batch_row_num = FLAGS_batch_row_num; + session_opt_.streaming_options.batch_row_num = FLAGS_batch_row_num; + // negotiate session options + Negotiate(); } Session::~Session() { util::PrometheusMonitor::GetInstance()->DecSessionNumberTotal(); - if (streaming_options_.batched) { + if (session_opt_.streaming_options.batched) { std::error_code ec; - std::filesystem::remove_all(streaming_options_.dump_file_dir, ec); + std::filesystem::remove_all(session_opt_.streaming_options.dump_file_dir, + ec); if (ec.value() != 0) { SPDLOG_WARN("can not remove tmp dir: {}, msg: {}", - streaming_options_.dump_file_dir.string(), ec.message()); + session_opt_.streaming_options.dump_file_dir.string(), + ec.message()); + } + } +} + +void Session::Negotiate() { + negotiate::NegotiationOptions options; + // fill options from flag + auto* psi_options = options.mutable_psi_options(); + psi_options->set_psi_curve_type(session_opt_.psi_config.psi_curve_type); + psi_options->set_unbalance_psi_larger_party_rows_count_threshold( + session_opt_.psi_config.unbalance_psi_larger_party_rows_count_threshold); + psi_options->set_unbalance_psi_ratio_threshold( + session_opt_.psi_config.unbalance_psi_ratio_threshold); + psi_options->set_use_rr22_low_comm_mode( + session_opt_.psi_config.low_comm_mode); + auto* streaming_options = options.mutable_streaming_options(); + streaming_options->set_streaming_row_num_threshold( + session_opt_.streaming_options.streaming_row_num_threshold); + streaming_options->set_batch_row_num( + session_opt_.streaming_options.batch_row_num); + // negotiation + std::vector negotiation_options_v; + if (lctx_->WorldSize() > 1) { + std::string option_str; + auto status = + google::protobuf::util::MessageToJsonString(options, &option_str); + YACL_ENFORCE(status.ok(), status.message()); + yacl::ByteContainerView buffer(option_str); + auto buffers = yacl::link::AllGather(lctx_, buffer, "negotiate"); + // parse + for (const auto& value : buffers) { + negotiate::NegotiationOptions tmp_options; + status = google::protobuf::util::JsonStringToMessage(value, &tmp_options); + YACL_ENFORCE(status.ok(), status.message()); + negotiation_options_v.push_back(std::move(tmp_options)); } + } else { + negotiation_options_v.push_back(options); + } + + // resolution + session_opt_.psi_config.low_comm_mode = true; + std::vector curve_types; + for (auto& negotiation_options : negotiation_options_v) { + // use fast mode if any party set use_rr22_low_comm_mode to false + if (!negotiation_options.psi_options().use_rr22_low_comm_mode()) { + session_opt_.psi_config.low_comm_mode = false; + } + SPDLOG_INFO( + "session_opt_: {}, {}", + session_opt_.psi_config.unbalance_psi_larger_party_rows_count_threshold, + negotiation_options.psi_options() + .unbalance_psi_larger_party_rows_count_threshold()); + // choose max unbalance_psi_larger_party_rows_count_threshold + if (session_opt_.psi_config + .unbalance_psi_larger_party_rows_count_threshold < + negotiation_options.psi_options() + .unbalance_psi_larger_party_rows_count_threshold()) { + session_opt_.psi_config.unbalance_psi_larger_party_rows_count_threshold = + negotiation_options.psi_options() + .unbalance_psi_larger_party_rows_count_threshold(); + } + // choose max unbalance_psi_ratio_threshold + if (session_opt_.psi_config.unbalance_psi_ratio_threshold < + negotiation_options.psi_options().unbalance_psi_ratio_threshold()) { + session_opt_.psi_config.unbalance_psi_ratio_threshold = + negotiation_options.psi_options().unbalance_psi_ratio_threshold(); + } + curve_types.push_back(static_cast( + negotiation_options.psi_options().psi_curve_type())); + // choose min streaming_row_num_threshold + if (static_cast( + session_opt_.streaming_options.streaming_row_num_threshold) > + negotiation_options.streaming_options().streaming_row_num_threshold()) { + session_opt_.streaming_options.streaming_row_num_threshold = + negotiation_options.streaming_options().streaming_row_num_threshold(); + } + // choose min batch_row_num + if (static_cast(session_opt_.streaming_options.batch_row_num) > + negotiation_options.streaming_options().batch_row_num()) { + session_opt_.streaming_options.batch_row_num = + negotiation_options.streaming_options().batch_row_num(); + } + } + // if the curve type between parties differ, it is considered invalid. + if (!std::all_of(curve_types.begin(), curve_types.end(), [&](int curve_type) { + return curve_type == curve_types[0]; + })) { + session_opt_.psi_config.psi_curve_type = psi::CurveType::CURVE_INVALID_TYPE; + } else { + session_opt_.psi_config.psi_curve_type = curve_types[0]; } } @@ -203,21 +303,22 @@ void Session::MergeDeviceSymbolsFrom(const spu::device::SymbolTable& other) { } void Session::EnableStreamingBatched() { - streaming_options_.batched = true; - size_t data[2] = {streaming_options_.batch_row_num, - streaming_options_.streaming_row_num_threshold}; + session_opt_.streaming_options.batched = true; + size_t data[2] = {session_opt_.streaming_options.batch_row_num, + session_opt_.streaming_options.streaming_row_num_threshold}; // get checksum from other parties auto bufs = yacl::link::AllGather( GetLink(), yacl::ByteContainerView(data, 2 * sizeof(size_t)), "streaming_options"); for (const auto& buf : bufs) { - streaming_options_.batch_row_num = - std::min(streaming_options_.batch_row_num, buf.data()[0]); - streaming_options_.streaming_row_num_threshold = std::min( - streaming_options_.streaming_row_num_threshold, buf.data()[1]); + session_opt_.streaming_options.batch_row_num = std::min( + session_opt_.streaming_options.batch_row_num, buf.data()[0]); + session_opt_.streaming_options.streaming_row_num_threshold = + std::min(session_opt_.streaming_options.streaming_row_num_threshold, + buf.data()[1]); } - if (streaming_options_.dump_file_dir.empty()) { - streaming_options_.dump_file_dir = + if (session_opt_.streaming_options.dump_file_dir.empty()) { + session_opt_.streaming_options.dump_file_dir = util::CreateDirWithRandSuffix(FLAGS_tmp_file_path, id_); } } @@ -394,7 +495,7 @@ void Session::UpdateRefName(const std::vector& input_ref_names, continue; } auto iter = tensor_ref_nums_.find(name); - if (!streaming_options_.batched) { + if (!session_opt_.streaming_options.batched) { YACL_ENFORCE(iter == tensor_ref_nums_.end(), "ref num of {} was set before created", name); } diff --git a/engine/framework/session.h b/engine/framework/session.h index 581ced42..156cd61b 100644 --- a/engine/framework/session.h +++ b/engine/framework/session.h @@ -72,27 +72,29 @@ struct PsiConfig { int64_t unbalance_psi_ratio_threshold = 0; int64_t unbalance_psi_larger_party_rows_count_threshold = 0; int32_t psi_curve_type = 0; + bool low_comm_mode = false; }; struct LogConfig { bool enable_session_logger_separation = false; }; -struct SessionOptions { - util::LogOptions log_options; - LinkConfig link_config; - PsiConfig psi_config; - LogConfig log_config; -}; - struct StreamingOptions { std::filesystem::path dump_file_dir; - bool batched; + bool batched = false; // if row num is less than this threshold, close streaming mode and keep all // data in memory - size_t streaming_row_num_threshold; + size_t streaming_row_num_threshold = 0; // if working in streaming mode, max row num in one batch - size_t batch_row_num; + size_t batch_row_num = 0; +}; + +struct SessionOptions { + util::LogOptions log_options; + LinkConfig link_config; + PsiConfig psi_config; + LogConfig log_config; + StreamingOptions streaming_options; }; /// @brief Session holds everything needed to run the execution plan. @@ -230,18 +232,21 @@ class Session { const SessionOptions& GetSessionOptions() const { return session_opt_; } - StreamingOptions GetStreamingOptions() { return streaming_options_; } + StreamingOptions GetStreamingOptions() const { + return session_opt_.streaming_options; + } void SetStreamingOptions(const StreamingOptions& streaming_options) { - streaming_options_ = streaming_options; + session_opt_.streaming_options = streaming_options; } void EnableStreamingBatched(); + void Negotiate(); private: void InitLink(); bool ValidateSPUContext(); const std::string id_; - const SessionOptions session_opt_; + SessionOptions session_opt_; const std::string time_zone_; PartyInfo parties_; std::atomic state_; @@ -281,8 +286,6 @@ class Session { std::string current_node_name_; std::chrono::time_point node_start_time_; std::shared_ptr progres_stats_; - // for streaming - StreamingOptions streaming_options_; }; std::shared_ptr ActiveLogger(const Session* session); diff --git a/engine/framework/session_negotiation.proto b/engine/framework/session_negotiation.proto new file mode 100644 index 00000000..0f99e074 --- /dev/null +++ b/engine/framework/session_negotiation.proto @@ -0,0 +1,38 @@ +// +// Copyright 2025 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +syntax = "proto3"; + +package scql.engine.negotiate; + +import "api/engine.proto"; + +message StreamingOptions { + int64 streaming_row_num_threshold = 1; + int64 batch_row_num = 2; +} + +message PsiOptions { + int64 unbalance_psi_ratio_threshold = 1; + int64 unbalance_psi_larger_party_rows_count_threshold = 2; + int32 psi_curve_type = 3; + bool use_rr22_low_comm_mode = 4; +} + +message NegotiationOptions { + StreamingOptions streaming_options = 1; + PsiOptions psi_options = 2; +} diff --git a/engine/operator/in.cc b/engine/operator/in.cc index ad1adfa0..2ed2e9f7 100644 --- a/engine/operator/in.cc +++ b/engine/operator/in.cc @@ -513,7 +513,9 @@ void In::Rr22PsiIn(ExecContext* ctx) { provider.CleanBucket(bucket_idx); }; psi::rr22::Rr22Runner runner( - psi_link, psi::rr22::GenerateRr22PsiOptions(FLAGS_use_rr22_low_comm_mode), + psi_link, + psi::rr22::GenerateRr22PsiOptions( + ctx->GetSession()->GetSessionOptions().psi_config.low_comm_mode), bucket_num, false, pre_f, post_f); // reveal party as receiver runner.AsyncRun(0, reveal_to != my_party_code); diff --git a/engine/operator/join.cc b/engine/operator/join.cc index d3d6fb08..1f385140 100644 --- a/engine/operator/join.cc +++ b/engine/operator/join.cc @@ -350,7 +350,9 @@ void Join::Rr22PsiJoin(ExecContext* ctx) { provider.CleanBucket(bucket_idx); }; psi::rr22::Rr22Runner runner( - psi_link, psi::rr22::GenerateRr22PsiOptions(FLAGS_use_rr22_low_comm_mode), + psi_link, + psi::rr22::GenerateRr22PsiOptions( + ctx->GetSession()->GetSessionOptions().psi_config.low_comm_mode), bucket_num, target_rank == yacl::link::kAllRank, pre_f, post_f); bool is_sender = is_left; // receiver should be the one who has more data diff --git a/engine/services/engine_service_impl_test.cc b/engine/services/engine_service_impl_test.cc index 55541c29..6baa71fb 100644 --- a/engine/services/engine_service_impl_test.cc +++ b/engine/services/engine_service_impl_test.cc @@ -34,6 +34,7 @@ #include "engine/operator/publish.h" #include "engine/operator/run_sql.h" #include "engine/operator/test_util.h" +#include "engine/util/concurrent_queue.h" #include "api/status_code.pb.h" #include "engine/services/mock_report_service.pb.h" @@ -527,6 +528,107 @@ TEST_P(EngineServiceImpl2PartiesTest, RunExecutionPlan) { } } +// control the time returned by the Report +class MockWaitReportServiceImpl : public services::pb::MockReportService { + public: + void Report(::google::protobuf::RpcController* controller, + const pb::ReportRequest* request, + services::pb::MockResponse* response, + ::google::protobuf::Closure* done) override { + brpc::ClosureGuard done_guard(done); + // push to alice and bob + output_chan.Push(1); + output_chan.Push(1); + // pull from alice and bob + input_chan.Pop(); + input_chan.Pop(); + } + + public: + util::SimpleChannel input_chan{1}; + util::SimpleChannel output_chan{1}; +}; + +// run the case: find persons who both exist in ta(Table A) and tb(Table B). +TEST_P(EngineServiceImpl2PartiesTest, SessionNegotiation) { + // Given + auto test_case = GetParam(); + auto session = PrepareTableInMemory( + test_case, "file:runsql_test?mode=memory&cache=shared"); + + // When + auto proc = [&](EngineServiceImpl* svc, pb::RunExecutionPlanRequest* request, + pb::RunExecutionPlanResponse* response) { + EXPECT_NO_THROW( + svc->RunExecutionPlan(&global_cntl, request, response, nullptr)); + EXPECT_EQ(pb::Code::OK, response->status().code()); + }; + + // start mock report service. + brpc::Server recv_server; + MockWaitReportServiceImpl service; + ASSERT_EQ(0, + recv_server.AddService(&service, brpc::SERVER_DOESNT_OWN_SERVICE)); + brpc::ServerOptions recv_options; + ASSERT_EQ(0, recv_server.Start("127.0.0.1:0", &recv_options)); + + pb::RunExecutionPlanResponse response_alice; + pb::RunExecutionPlanRequest request_alice = ConstructRequestForAlice(servers); + // modify request + request_alice.set_async(true); + request_alice.set_callback_url( + fmt::format("{}/MockReportService/Report", + butil::endpoint2str(recv_server.listen_address()).c_str())); + auto* alice_job_params = request_alice.mutable_job_params(); + auto* alice_psi_conf = alice_job_params->mutable_psi_cfg(); + alice_psi_conf->set_psi_curve_type( + static_cast(psi::CurveType::CURVE_25519)); + alice_psi_conf->set_unbalance_psi_larger_party_rows_count_threshold(10000); + auto future_alice = + std::async(proc, engine_svcs[0].get(), &request_alice, &response_alice); + + pb::RunExecutionPlanResponse response_bob; + pb::RunExecutionPlanRequest request_bob = ConstructRequestForBob(servers); + request_bob.set_async(true); + // only check alice result + request_bob.set_callback_url( + fmt::format("{}/MockReportService/Report", + butil::endpoint2str(recv_server.listen_address()).c_str())); + auto* bob_job_params = request_bob.mutable_job_params(); + auto* bob_psi_conf = bob_job_params->mutable_psi_cfg(); + bob_psi_conf->set_psi_curve_type( + static_cast(psi::CurveType::CURVE_FOURQ)); + bob_psi_conf->set_unbalance_psi_larger_party_rows_count_threshold(100000); + auto future_bob = + std::async(proc, engine_svcs[1].get(), &request_bob, &response_bob); + // Then + EXPECT_NO_THROW(future_alice.get()); + SPDLOG_INFO("out: \n{}", response_alice.DebugString()); + + EXPECT_NO_THROW(future_bob.get()); + SPDLOG_INFO("out: \n{}", response_bob.DebugString()); + service.output_chan.Pop(); // alice wait async run finished. + service.output_chan.Pop(); // bob wait async run finished. + auto* session_alice = engine_svcs[0]->GetSessionManager()->GetSession( + request_alice.job_params().job_id()); + EXPECT_EQ(session_alice->GetSessionOptions().psi_config.psi_curve_type, + psi::CurveType::CURVE_INVALID_TYPE); + EXPECT_EQ(session_alice->GetSessionOptions() + .psi_config.unbalance_psi_larger_party_rows_count_threshold, + 100000); + auto* session_bob = engine_svcs[1]->GetSessionManager()->GetSession( + request_alice.job_params().job_id()); + EXPECT_EQ(session_bob->GetSessionOptions().psi_config.psi_curve_type, + psi::CurveType::CURVE_INVALID_TYPE); + EXPECT_EQ(session_bob->GetSessionOptions() + .psi_config.unbalance_psi_larger_party_rows_count_threshold, + 100000); + service.input_chan.Push(1); + service.input_chan.Push(1); + recv_server.Stop(1000); + recv_server.Join(); +} + /// =========================== /// Test for 2 Parties Implementation /// =========================== From fa08062e80cff6620c90e56b84398d72283a6a0f Mon Sep 17 00:00:00 2001 From: "yuyongqiang.yyq" Date: Mon, 8 Sep 2025 03:34:07 +0000 Subject: [PATCH 2/4] fix comment --- engine/framework/session.cc | 56 ++++++---------------- engine/framework/session_negotiation.proto | 2 - 2 files changed, 15 insertions(+), 43 deletions(-) diff --git a/engine/framework/session.cc b/engine/framework/session.cc index fabe3fa1..bb4b0b87 100644 --- a/engine/framework/session.cc +++ b/engine/framework/session.cc @@ -20,7 +20,6 @@ #include #include -#include "algorithm" #include "arrow/visit_array_inline.h" #include "google/protobuf/util/json_util.h" #include "libspu/core/config.h" @@ -220,41 +219,28 @@ void Session::Negotiate() { if (!negotiation_options.psi_options().use_rr22_low_comm_mode()) { session_opt_.psi_config.low_comm_mode = false; } - SPDLOG_INFO( - "session_opt_: {}, {}", + // choose max unbalance_psi_larger_party_rows_count_threshold + session_opt_.psi_config + .unbalance_psi_larger_party_rows_count_threshold = std::max( session_opt_.psi_config.unbalance_psi_larger_party_rows_count_threshold, negotiation_options.psi_options() .unbalance_psi_larger_party_rows_count_threshold()); - // choose max unbalance_psi_larger_party_rows_count_threshold - if (session_opt_.psi_config - .unbalance_psi_larger_party_rows_count_threshold < - negotiation_options.psi_options() - .unbalance_psi_larger_party_rows_count_threshold()) { - session_opt_.psi_config.unbalance_psi_larger_party_rows_count_threshold = - negotiation_options.psi_options() - .unbalance_psi_larger_party_rows_count_threshold(); - } // choose max unbalance_psi_ratio_threshold - if (session_opt_.psi_config.unbalance_psi_ratio_threshold < - negotiation_options.psi_options().unbalance_psi_ratio_threshold()) { - session_opt_.psi_config.unbalance_psi_ratio_threshold = - negotiation_options.psi_options().unbalance_psi_ratio_threshold(); - } + session_opt_.psi_config.unbalance_psi_ratio_threshold = std::max( + session_opt_.psi_config.unbalance_psi_ratio_threshold, + negotiation_options.psi_options().unbalance_psi_ratio_threshold()); curve_types.push_back(static_cast( negotiation_options.psi_options().psi_curve_type())); // choose min streaming_row_num_threshold - if (static_cast( - session_opt_.streaming_options.streaming_row_num_threshold) > - negotiation_options.streaming_options().streaming_row_num_threshold()) { - session_opt_.streaming_options.streaming_row_num_threshold = - negotiation_options.streaming_options().streaming_row_num_threshold(); - } + session_opt_.streaming_options.streaming_row_num_threshold = + std::min(session_opt_.streaming_options.streaming_row_num_threshold, + static_cast(negotiation_options.streaming_options() + .streaming_row_num_threshold())); // choose min batch_row_num - if (static_cast(session_opt_.streaming_options.batch_row_num) > - negotiation_options.streaming_options().batch_row_num()) { - session_opt_.streaming_options.batch_row_num = - negotiation_options.streaming_options().batch_row_num(); - } + session_opt_.streaming_options.batch_row_num = + std::min(session_opt_.streaming_options.batch_row_num, + static_cast( + negotiation_options.streaming_options().batch_row_num())); } // if the curve type between parties differ, it is considered invalid. if (!std::all_of(curve_types.begin(), curve_types.end(), [&](int curve_type) { @@ -302,21 +288,9 @@ void Session::MergeDeviceSymbolsFrom(const spu::device::SymbolTable& other) { } } +// set batched true and create tmp dir for streaming void Session::EnableStreamingBatched() { session_opt_.streaming_options.batched = true; - size_t data[2] = {session_opt_.streaming_options.batch_row_num, - session_opt_.streaming_options.streaming_row_num_threshold}; - // get checksum from other parties - auto bufs = yacl::link::AllGather( - GetLink(), yacl::ByteContainerView(data, 2 * sizeof(size_t)), - "streaming_options"); - for (const auto& buf : bufs) { - session_opt_.streaming_options.batch_row_num = std::min( - session_opt_.streaming_options.batch_row_num, buf.data()[0]); - session_opt_.streaming_options.streaming_row_num_threshold = - std::min(session_opt_.streaming_options.streaming_row_num_threshold, - buf.data()[1]); - } if (session_opt_.streaming_options.dump_file_dir.empty()) { session_opt_.streaming_options.dump_file_dir = util::CreateDirWithRandSuffix(FLAGS_tmp_file_path, id_); diff --git a/engine/framework/session_negotiation.proto b/engine/framework/session_negotiation.proto index 0f99e074..8a0a8a84 100644 --- a/engine/framework/session_negotiation.proto +++ b/engine/framework/session_negotiation.proto @@ -18,8 +18,6 @@ syntax = "proto3"; package scql.engine.negotiate; -import "api/engine.proto"; - message StreamingOptions { int64 streaming_row_num_threshold = 1; int64 batch_row_num = 2; From a9a0ef3d35adc36ad0b4bf5277d20fa506ea46a3 Mon Sep 17 00:00:00 2001 From: "yuyongqiang.yyq" Date: Mon, 8 Sep 2025 03:46:31 +0000 Subject: [PATCH 3/4] fix test --- engine/services/engine_service_impl_test.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/engine/services/engine_service_impl_test.cc b/engine/services/engine_service_impl_test.cc index 6baa71fb..9fc54666 100644 --- a/engine/services/engine_service_impl_test.cc +++ b/engine/services/engine_service_impl_test.cc @@ -538,10 +538,8 @@ class MockWaitReportServiceImpl : public services::pb::MockReportService { brpc::ClosureGuard done_guard(done); // push to alice and bob output_chan.Push(1); - output_chan.Push(1); // pull from alice and bob input_chan.Pop(); - input_chan.Pop(); } public: @@ -601,14 +599,14 @@ TEST_P(EngineServiceImpl2PartiesTest, SessionNegotiation) { bob_psi_conf->set_unbalance_psi_larger_party_rows_count_threshold(100000); auto future_bob = std::async(proc, engine_svcs[1].get(), &request_bob, &response_bob); + service.output_chan.Pop(); // alice wait async run finished. + service.output_chan.Pop(); // bob wait async run finished. // Then EXPECT_NO_THROW(future_alice.get()); SPDLOG_INFO("out: \n{}", response_alice.DebugString()); EXPECT_NO_THROW(future_bob.get()); SPDLOG_INFO("out: \n{}", response_bob.DebugString()); - service.output_chan.Pop(); // alice wait async run finished. - service.output_chan.Pop(); // bob wait async run finished. auto* session_alice = engine_svcs[0]->GetSessionManager()->GetSession( request_alice.job_params().job_id()); EXPECT_EQ(session_alice->GetSessionOptions().psi_config.psi_curve_type, From 06f4e22d31518e1d27ceaa356ddbbfd956ba0605 Mon Sep 17 00:00:00 2001 From: "yuyongqiang.yyq" Date: Tue, 9 Sep 2025 03:18:35 +0000 Subject: [PATCH 4/4] clean useless comment --- engine/services/engine_service_impl_test.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/engine/services/engine_service_impl_test.cc b/engine/services/engine_service_impl_test.cc index 9fc54666..22d21856 100644 --- a/engine/services/engine_service_impl_test.cc +++ b/engine/services/engine_service_impl_test.cc @@ -547,7 +547,6 @@ class MockWaitReportServiceImpl : public services::pb::MockReportService { util::SimpleChannel output_chan{1}; }; -// run the case: find persons who both exist in ta(Table A) and tb(Table B). TEST_P(EngineServiceImpl2PartiesTest, SessionNegotiation) { // Given auto test_case = GetParam(); @@ -588,7 +587,6 @@ TEST_P(EngineServiceImpl2PartiesTest, SessionNegotiation) { pb::RunExecutionPlanResponse response_bob; pb::RunExecutionPlanRequest request_bob = ConstructRequestForBob(servers); request_bob.set_async(true); - // only check alice result request_bob.set_callback_url( fmt::format("{}/MockReportService/Report", butil::endpoint2str(recv_server.listen_address()).c_str()));