diff --git a/xprof/convert/smart_suggestion/BUILD b/xprof/convert/smart_suggestion/BUILD index e004dd552..b2d8fe77b 100644 --- a/xprof/convert/smart_suggestion/BUILD +++ b/xprof/convert/smart_suggestion/BUILD @@ -35,6 +35,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@org_xprof//plugin/xprof/protobuf:input_pipeline_proto_cc", "@org_xprof//plugin/xprof/protobuf:overview_page_proto_cc", + "@tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -43,6 +44,7 @@ cc_library( hdrs = ["tool_data_provider_impl.h"], deps = [ ":tool_data_provider", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@org_xprof//plugin/xprof/protobuf:input_pipeline_proto_cc", "@org_xprof//plugin/xprof/protobuf:op_stats_proto_cc", @@ -110,6 +112,20 @@ cc_library( ], ) +cc_library( + name = "special_op_rule", + hdrs = ["special_op_rule.h"], + deps = [ + ":signal_provider", + ":smart_suggestion_rule", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@org_xprof//plugin/xprof/protobuf:smart_suggestion_proto_cc", + "@xla//xla/tsl/platform:statusor", + ], +) + cc_library( name = "smart_suggestion_rule_factory", hdrs = ["smart_suggestion_rule_factory.h"], @@ -127,6 +143,7 @@ cc_library( ":input_bound_rule", ":memory_bound_rule", ":smart_suggestion_rule_factory", + ":special_op_rule", ], ) @@ -144,6 +161,20 @@ cc_library( ], ) +cc_test( + name = "special_op_rule_test", + srcs = ["special_op_rule_test.cc"], + deps = [ + ":signal_provider", + ":special_op_rule", + ":tool_data_provider", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@org_xprof//plugin/xprof/protobuf:smart_suggestion_proto_cc", + ], +) + cc_test( name = "host_processing_bound_rule_test", srcs = ["host_processing_bound_rule_test.cc"], diff --git a/xprof/convert/smart_suggestion/all_rules.h b/xprof/convert/smart_suggestion/all_rules.h index 
201536298..87d9054f8 100644 --- a/xprof/convert/smart_suggestion/all_rules.h +++ b/xprof/convert/smart_suggestion/all_rules.h @@ -21,6 +21,7 @@ limitations under the License. #include "xprof/convert/smart_suggestion/input_bound_rule.h" #include "xprof/convert/smart_suggestion/memory_bound_rule.h" #include "xprof/convert/smart_suggestion/smart_suggestion_rule_factory.h" +#include "xprof/convert/smart_suggestion/special_op_rule.h" namespace tensorflow { namespace profiler { @@ -32,6 +33,7 @@ inline void RegisterAllRules(SmartSuggestionRuleFactory* f) { f->Register(); f->Register(); f->Register(); + f->Register(); // go/keep-sorted end } diff --git a/xprof/convert/smart_suggestion/data_transfer_bound_rule_test.cc b/xprof/convert/smart_suggestion/data_transfer_bound_rule_test.cc index 0882bc2f3..2aaa5eb3b 100644 --- a/xprof/convert/smart_suggestion/data_transfer_bound_rule_test.cc +++ b/xprof/convert/smart_suggestion/data_transfer_bound_rule_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include #include +#include +#include #include "testing/base/public/gmock.h" #include "" @@ -43,6 +45,8 @@ class MockToolDataProvider : public ToolDataProvider { (override)); MOCK_METHOD(absl::StatusOr, GetInputPipelineAnalysisResult, (), (override)); + MOCK_METHOD(absl::StatusOr>, + GetEventTimeFractionEachStep, (const std::string&), (override)); }; TEST(DataTransferBoundRuleTest, MeetsConditions) { diff --git a/xprof/convert/smart_suggestion/host_processing_bound_rule_test.cc b/xprof/convert/smart_suggestion/host_processing_bound_rule_test.cc index c058c9b32..a0728cf18 100644 --- a/xprof/convert/smart_suggestion/host_processing_bound_rule_test.cc +++ b/xprof/convert/smart_suggestion/host_processing_bound_rule_test.cc @@ -18,6 +18,8 @@ limitations under the License. 
#include #include #include +#include +#include #include "testing/base/public/gmock.h" #include "" @@ -43,6 +45,8 @@ class MockToolDataProvider : public ToolDataProvider { (override)); MOCK_METHOD(absl::StatusOr, GetInputPipelineAnalysisResult, (), (override)); + MOCK_METHOD((absl::StatusOr>), + GetEventTimeFractionEachStep, (const std::string&), (override)); }; TEST(HostProcessingBoundRuleTest, MeetsConditions) { diff --git a/xprof/convert/smart_suggestion/memory_bound_rule_test.cc b/xprof/convert/smart_suggestion/memory_bound_rule_test.cc index 047d6e7dc..a0231916b 100644 --- a/xprof/convert/smart_suggestion/memory_bound_rule_test.cc +++ b/xprof/convert/smart_suggestion/memory_bound_rule_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include #include +#include +#include #include "testing/base/public/gmock.h" #include "" @@ -43,6 +45,8 @@ class MockToolDataProvider : public ToolDataProvider { (override)); MOCK_METHOD(absl::StatusOr, GetInputPipelineAnalysisResult, (), (override)); + MOCK_METHOD(absl::StatusOr>, + GetEventTimeFractionEachStep, (const std::string&), (override)); }; TEST(MemoryBoundRuleTest, MeetsConditions) { diff --git a/xprof/convert/smart_suggestion/signal_provider.h b/xprof/convert/smart_suggestion/signal_provider.h index 6d6c9f8c6..f2ba8dd7e 100644 --- a/xprof/convert/smart_suggestion/signal_provider.h +++ b/xprof/convert/smart_suggestion/signal_provider.h @@ -17,6 +17,7 @@ limitations under the License. #define THIRD_PARTY_XPROF_CONVERT_SMART_SUGGESTION_SIGNAL_PROVIDER_H_ #include +#include #include #include "absl/status/statusor.h" @@ -103,6 +104,24 @@ class SignalProvider { return (non_enqueue_us / total_input_time_us) * 100.0; } + // Returns the average percentage of step time for a given event name. + // Averages the per-step fractions reported by + // ToolDataProvider::GetEventTimeFractionEachStep(event_name) and scales the + // mean by 100. Yields 0.0 when no steps are reported; a provider error is + // propagated to the caller via TF_ASSIGN_OR_RETURN. 
+ absl::StatusOr GetAvgEventTimePercent( + const std::string& event_name) const { + TF_ASSIGN_OR_RETURN( + auto event_time_of_interest, + tool_data_provider_->GetEventTimeFractionEachStep(event_name)); + + double total_percent = 0; + for (float event_percent : event_time_of_interest) { + total_percent += event_percent; + } + + if (event_time_of_interest.empty()) { + return 0.0; + } + return (total_percent / event_time_of_interest.size()) * 100.0; + } + private: std::unique_ptr tool_data_provider_; }; diff --git a/xprof/convert/smart_suggestion/special_op_rule.h b/xprof/convert/smart_suggestion/special_op_rule.h new file mode 100644 index 000000000..0f2c39d79 --- /dev/null +++ b/xprof/convert/smart_suggestion/special_op_rule.h @@ -0,0 +1,84 @@ +#ifndef THIRD_PARTY_XPROF_CONVERT_SMART_SUGGESTION_SPECIAL_OP_RULE_H_ +#define THIRD_PARTY_XPROF_CONVERT_SMART_SUGGESTION_SPECIAL_OP_RULE_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "xla/tsl/platform/statusor.h" +#include "xprof/convert/smart_suggestion/signal_provider.h" +#include "xprof/convert/smart_suggestion/smart_suggestion_rule.h" +#include "plugin/xprof/protobuf/smart_suggestion.pb.h" + +namespace tensorflow { +namespace profiler { + +// The name of the special op we are interested in, by default barrier-cores. +// TODO(zhuruiyang): We will need to update it to support other special ops with +// a vector of op strings. +constexpr char kSpecialOpName[] = "barrier-cores"; +// If the percentage of step time that is due to the special op is higher than +// this threshold, it is considered a bottleneck. +constexpr double kSpecialOpBoundThresholdInPercent = 10; + +// Rule to detect high special op percentage bottleneck. 
+class SpecialOpRule : public SmartSuggestionRule { + public: + bool MeetsConditions(const SignalProvider& signal_provider) const override { + absl::StatusOr special_op_percent = + signal_provider.GetAvgEventTimePercent(kSpecialOpName); + if (!special_op_percent.ok()) { + return false; + } + + return *special_op_percent >= kSpecialOpBoundThresholdInPercent; + } + + // Builds the suggestion text for the special op, embedding the average + // percentage of step time obtained from the SignalProvider. The threshold + // comparison is not repeated here: it lives in MeetsConditions() above, which + // also returns false when the percentage cannot be fetched, whereas this + // method propagates that error to the caller. + // NOTE(review): presumably SmartSuggestionRule::Apply only calls this after + // MeetsConditions() returns true - confirm against the base class. + // TODO(zhuruiyang): We will need to update it to support other special ops + // with a vector of op strings. Currently the suggestion text only supports + // barrier-cores. + absl::StatusOr> GenerateSuggestion( + const SignalProvider& signal_provider) const override { + SmartSuggestion suggestion; + suggestion.set_rule_name("SpecialOpRule"); + TF_ASSIGN_OR_RETURN(double special_op_percent, + signal_provider.GetAvgEventTimePercent(kSpecialOpName)); + auto display_name = absl::StrCat("TPU ", kSpecialOpName); + // TODO(zhuruiyang): The current suggestion text is hard-coded for just + // barrier-cores. We will need to update it to support other special ops. + std::string suggestion_text = absl::StrCat( + "

Your program is likely bottlenecked by ", display_name, + " operations: an average of ", + absl::StrFormat("%.1f", special_op_percent), + "% of each step time is spent on these operations. This " + "often indicates a synchronization issue between workers in a " + "distributed training setup. Please consider the following " + "optimizations:

", + "
    " + "
  • Investigate Workload Balance: Check for stragglers, i.e., " + "workers that are significantly slower than others. Uneven workloads " + "can cause faster workers to wait at the barrier.
  • " + "
  • Optimize Collective Operations: Operations like AllReduce " + "involve synchronization. Ensure they are used efficiently. Check " + "the size of data being communicated.
  • " + "
  • Check Network: Network latency or bandwidth can be a " + "bottleneck for distributed operations, causing workers to wait " + "longer at barriers.
  • " + "
  • Improve Data Input Pipeline: Ensure your data loading and " + "preprocessing pipeline is efficient and balanced across all " + "workers. A slow input pipeline on one worker can stall all " + "others.
  • " + "
"); + + suggestion.set_suggestion_text(suggestion_text); + return suggestion; + } +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // THIRD_PARTY_XPROF_CONVERT_SMART_SUGGESTION_SPECIAL_OP_RULE_H_ diff --git a/xprof/convert/smart_suggestion/special_op_rule_test.cc b/xprof/convert/smart_suggestion/special_op_rule_test.cc new file mode 100644 index 000000000..e5e0f917a --- /dev/null +++ b/xprof/convert/smart_suggestion/special_op_rule_test.cc @@ -0,0 +1,86 @@ +#include "xprof/convert/smart_suggestion/special_op_rule.h" + +#include +#include +#include +#include +#include + +#include "testing/base/public/gmock.h" +#include "" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xprof/convert/smart_suggestion/signal_provider.h" +#include "xprof/convert/smart_suggestion/tool_data_provider.h" +#include "plugin/xprof/protobuf/smart_suggestion.pb.h" + +namespace tensorflow { +namespace profiler { +namespace { + +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::Return; +using ::testing::status::IsOkAndHolds; + +// Mock ToolDataProvider +class MockToolDataProvider : public ToolDataProvider { + public: + MOCK_METHOD(absl::StatusOr, GetOverviewPage, (), + (override)); + MOCK_METHOD(absl::StatusOr, + GetInputPipelineAnalysisResult, (), (override)); + MOCK_METHOD(absl::StatusOr>, + GetEventTimeFractionEachStep, (const std::string&), (override)); +}; + +TEST(SpecialOpRuleTest, MeetsConditions) { + auto mock_tool_data_provider = std::make_unique(); + // Average is (0.15+0.25)/2 = 0.2, which is 20%. This is > 10%. 
+ EXPECT_CALL(*mock_tool_data_provider, + GetEventTimeFractionEachStep(kSpecialOpName)) + .WillRepeatedly(Return(std::vector{0.15, 0.25})); + + SignalProvider signal_provider(std::move(mock_tool_data_provider)); + SpecialOpRule rule; + + absl::StatusOr> suggestion = + rule.Apply(signal_provider); + EXPECT_THAT(suggestion, IsOkAndHolds(testing::Not(Eq(std::nullopt)))); + EXPECT_EQ((*suggestion)->rule_name(), "SpecialOpRule"); + EXPECT_THAT((*suggestion)->suggestion_text(), + HasSubstr("20.0% of each step time")); +} + +TEST(SpecialOpRuleTest, NotSpecialOpBound) { + auto mock_tool_data_provider = std::make_unique(); + // Average is (0.01+0.02)/2 = 0.015, which is 1.5%. This is < 10%. + EXPECT_CALL(*mock_tool_data_provider, + GetEventTimeFractionEachStep(kSpecialOpName)) + .WillRepeatedly(Return(std::vector{0.01, 0.02})); + + SignalProvider signal_provider(std::move(mock_tool_data_provider)); + SpecialOpRule rule; + + absl::StatusOr> suggestion = + rule.Apply(signal_provider); + EXPECT_THAT(suggestion, IsOkAndHolds(Eq(std::nullopt))); +} + +TEST(SpecialOpRuleTest, ErrorFetchingPercentile) { + auto mock_tool_data_provider = std::make_unique(); + // SpecialOpRule::MeetsConditions() returns false when the provider reports + // an error, so this test expects Apply to yield OK with std::nullopt instead + // of surfacing the InternalError. + EXPECT_CALL(*mock_tool_data_provider, + GetEventTimeFractionEachStep(kSpecialOpName)) + .WillRepeatedly(Return(absl::InternalError("Failed to get percentile"))); + + SignalProvider signal_provider(std::move(mock_tool_data_provider)); + SpecialOpRule rule; + + absl::StatusOr> suggestion = + rule.Apply(signal_provider); + EXPECT_THAT(suggestion, IsOkAndHolds(Eq(std::nullopt))); +} + +} // namespace +} // namespace profiler +} // namespace tensorflow diff --git a/xprof/convert/smart_suggestion/tool_data_provider.h b/xprof/convert/smart_suggestion/tool_data_provider.h index 682dd9f63..231cc8afd 100644 --- a/xprof/convert/smart_suggestion/tool_data_provider.h +++ b/xprof/convert/smart_suggestion/tool_data_provider.h @@ -16,9 +16,13 @@ limitations under the License. 
#ifndef THIRD_PARTY_XPROF_CONVERT_SMART_SUGGESTION_TOOL_DATA_PROVIDER_H_ #define THIRD_PARTY_XPROF_CONVERT_SMART_SUGGESTION_TOOL_DATA_PROVIDER_H_ +#include +#include + #include "absl/status/statusor.h" #include "plugin/xprof/protobuf/input_pipeline.pb.h" #include "plugin/xprof/protobuf/overview_page.pb.h" +#include "tsl/profiler/protobuf/xplane.pb.h" namespace tensorflow { namespace profiler { @@ -34,6 +38,12 @@ class ToolDataProvider { // Returns the InputPipelineAnalysisResult data. virtual absl::StatusOr GetInputPipelineAnalysisResult() = 0; + + // Returns the event time fraction of each step for a given event name. + // The result is a flat sequence with one fraction per step (callers iterate + // it as floats); it is not keyed by plane name. + // NOTE(review): SignalProvider multiplies the mean by 100 to obtain a + // percentage, so values are presumably in the range [0, 1] - confirm. + virtual absl::StatusOr> + GetEventTimeFractionEachStep(const std::string& target_event_name) = 0; }; } // namespace profiler diff --git a/xprof/convert/smart_suggestion/tool_data_provider_impl.h b/xprof/convert/smart_suggestion/tool_data_provider_impl.h index 463c2745c..2689c5205 100644 --- a/xprof/convert/smart_suggestion/tool_data_provider_impl.h +++ b/xprof/convert/smart_suggestion/tool_data_provider_impl.h @@ -18,8 +18,11 @@ limitations under the License. #include #include +#include +#include #include "absl/status/statusor.h" +#include "absl/status/status.h" #include "xla/tsl/platform/errors.h" #include "xprof/convert/multi_xplanes_to_op_stats.h" #include "xprof/convert/op_stats_to_input_pipeline_analysis.h" @@ -67,6 +70,12 @@ class ToolDataProviderImpl : public ToolDataProvider { return input_pipeline_analysis_cache_.get(); } + // Placeholder: ignores target_event_name and always returns an + // Unimplemented error. + absl::StatusOr> + GetEventTimeFractionEachStep + (const std::string& target_event_name) override { + return absl::UnimplementedError("Not implemented yet."); + } + private: const SessionSnapshot& session_snapshot_; std::unique_ptr overview_page_cache_;