Skip to content

Commit 5a280d3

Browse files
muditgokhale2copybara-github
authored andcommitted
Add all_hosts information to the session_snapshot and move the device collision logic for trace_viewer to CreateTraceEventsContainer.
PiperOrigin-RevId: 822654197
1 parent 1697228 commit 5a280d3

19 files changed

+171
-101
lines changed

plugin/xprof/convert/raw_to_tool_data.py

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,13 @@ def process_raw_trace(raw_trace):
4141
return ''.join(trace_events_json.TraceEventsJsonStream(trace))
4242

4343

44-
def xspace_to_tools_data_from_byte_string(xspace_byte_list, filenames, tool,
45-
params):
44+
def xspace_to_tools_data_from_byte_string(xspace_byte_list, all_hosts,
45+
filenames, tool, params):
4646
"""Helper function for getting an XSpace tool from a bytes string.
4747
4848
Args:
4949
xspace_byte_list: A list of byte strings read from a XSpace proto file.
50+
all_hosts: A list of all hosts in the session.
5051
filenames: Names of the read files.
5152
tool: A string of tool name.
5253
params: user input parameters.
@@ -57,7 +58,7 @@ def xspace_to_tools_data_from_byte_string(xspace_byte_list, filenames, tool,
5758
# pylint:disable=dangerous-default-value
5859
def xspace_wrapper_func(xspace_arg, tool_arg, params={}):
5960
return _pywrap_profiler_plugin.xspace_to_tools_data_from_byte_string(
60-
xspace_arg, filenames, tool_arg, params)
61+
xspace_arg, all_hosts, filenames, tool_arg, params)
6162
# pylint:enable=dangerous-default-value
6263

6364
return xspace_to_tool_data(xspace_byte_list, tool, params,
@@ -73,22 +74,26 @@ def xspace_to_tool_names(xspace_paths):
7374
Returns:
7475
Returns a list of tool names.
7576
"""
77+
# xspace_to_tools_data expects all_hosts as the second argument, passing an
78+
# empty list.
7679
raw_data, success = _pywrap_profiler_plugin.xspace_to_tools_data(
77-
xspace_paths, 'tool_names')
80+
xspace_paths, [], 'tool_names', {})
7881
if success:
7982
return [tool for tool in raw_data.decode().split(',')]
8083
return []
8184

8285

8386
def xspace_to_tool_data(
8487
xspace_paths,
88+
all_hosts,
8589
tool,
8690
params,
8791
xspace_wrapper_func=_pywrap_profiler_plugin.xspace_to_tools_data):
8892
"""Converts XSpace to tool data string.
8993
9094
Args:
9195
xspace_paths: A list of XSpace paths.
96+
all_hosts: A list of all hosts in the session.
9297
tool: A string of tool name.
9398
params: user input parameters.
9499
xspace_wrapper_func: A callable that takes a list of strings and a tool and
@@ -112,26 +117,31 @@ def xspace_to_tool_data(
112117
if tool == 'trace_viewer':
113118
# Trace viewer handles one host at a time.
114119
assert len(xspace_paths) == 1
115-
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
120+
raw_data, success = xspace_wrapper_func(
121+
xspace_paths, all_hosts, tool, options)
116122
if success:
117123
data = process_raw_trace(raw_data)
118124
elif tool == 'trace_viewer@':
119125
options = params.get('trace_viewer_options', {})
120126
options['use_saved_result'] = params.get('use_saved_result', True)
121-
options['hosts'] = params.get('hosts', [])
122-
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
127+
options['hosts'] = all_hosts
128+
raw_data, success = xspace_wrapper_func(
129+
xspace_paths, all_hosts, tool, options)
123130
if success:
124131
data = raw_data
125132
elif tool == 'overview_page':
126-
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
133+
json_data, success = xspace_wrapper_func(
134+
xspace_paths, all_hosts, tool, options)
127135
if success:
128136
data = json_data
129137
elif tool == 'input_pipeline_analyzer':
130-
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
138+
json_data, success = xspace_wrapper_func(
139+
xspace_paths, all_hosts, tool, options)
131140
if success:
132141
data = json_data
133142
elif tool == 'framework_op_stats':
134-
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
143+
json_data, success = xspace_wrapper_func(
144+
xspace_paths, all_hosts, tool, options)
135145
if success:
136146
if tqx == 'out:csv':
137147
data = csv_writer.json_to_csv(json_data)
@@ -142,15 +152,16 @@ def xspace_to_tool_data(
142152
# TODO(b/419013992): Remove this tool completely as it has been deprecated
143153
legacy_tool = 'tensorflow_stats'
144154
json_data, success = xspace_wrapper_func(
145-
xspace_paths, legacy_tool, options
155+
xspace_paths, all_hosts, legacy_tool, options
146156
)
147157
if success:
148158
if tqx == 'out:csv':
149159
data = csv_writer.json_to_csv(json_data)
150160
else:
151161
data = json_data
152162
elif tool == 'kernel_stats':
153-
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
163+
json_data, success = xspace_wrapper_func(
164+
xspace_paths, all_hosts, tool, options)
154165
if success:
155166
if tqx == 'out:csv':
156167
data = csv_writer.json_to_csv(json_data)
@@ -159,37 +170,44 @@ def xspace_to_tool_data(
159170
elif tool == 'memory_profile':
160171
# Memory profile handles one host at a time.
161172
assert len(xspace_paths) == 1
162-
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
173+
raw_data, success = xspace_wrapper_func(
174+
xspace_paths, all_hosts, tool, options)
163175
if success:
164176
data = raw_data
165177
elif tool == 'pod_viewer':
166-
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
178+
raw_data, success = xspace_wrapper_func(
179+
xspace_paths, all_hosts, tool, options)
167180
if success:
168181
data = raw_data
169182
elif tool == 'op_profile':
170183
options['group_by'] = params.get('group_by', 'program')
171-
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
184+
raw_data, success = xspace_wrapper_func(
185+
xspace_paths, all_hosts, tool, options)
172186
if success:
173187
data = raw_data
174188
elif tool == 'hlo_op_profile':
175189
options['group_by'] = params.get('group_by', 'program')
176-
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
190+
raw_data, success = xspace_wrapper_func(
191+
xspace_paths, all_hosts, tool, options)
177192
if success:
178193
data = raw_data
179194
elif tool == 'hlo_stats':
180-
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
195+
json_data, success = xspace_wrapper_func(
196+
xspace_paths, all_hosts, tool, options)
181197
if success:
182198
data = json_data
183199
elif tool == 'roofline_model':
184-
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
200+
json_data, success = xspace_wrapper_func(
201+
xspace_paths, all_hosts, tool, options)
185202
if success:
186203
data = json_data
187204
elif tool == 'graph_viewer':
188205
download_hlo_types = ['pb', 'pbtxt', 'json', 'short_txt', 'long_txt']
189206
graph_html_type = 'graph'
190207
options = params.get('graph_viewer_options', {})
191208
options['use_saved_result'] = params.get('use_saved_result', True)
192-
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
209+
raw_data, success = xspace_wrapper_func(
210+
xspace_paths, all_hosts, tool, options)
193211
if success:
194212
data = raw_data
195213
content_type = 'text/plain'
@@ -213,18 +231,21 @@ def xspace_to_tool_data(
213231
'view_memory_allocation_timeline': view_memory_allocation_timeline,
214232
'memory_space': params.get('memory_space', ''),
215233
}
216-
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
234+
raw_data, success = xspace_wrapper_func(
235+
xspace_paths, all_hosts, tool, options)
217236
if success:
218237
data = raw_data
219238
if view_memory_allocation_timeline:
220239
content_type = 'text/html'
221240
elif tool == 'megascale_stats':
222241
options = {'host_name': params.get('host')}
223-
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
242+
json_data, success = xspace_wrapper_func(
243+
xspace_paths, all_hosts, tool, options)
224244
if success:
225245
data = json_data
226246
elif tool == 'inference_profile':
227-
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
247+
json_data, success = xspace_wrapper_func(
248+
xspace_paths, all_hosts, tool, options)
228249
if success:
229250
data = json_data
230251
else:

plugin/xprof/convert/raw_to_tool_data_test.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ def test_using_old_tool_format_maps_to_new_format(self):
2727
xspace_paths=["/path/to/xspace"],
2828
tool="trace_viewer@^",
2929
params={},
30-
xspace_wrapper_func=lambda paths, tool, options: (tool.encode(), True),
30+
all_hosts=[],
31+
xspace_wrapper_func=lambda paths, hosts, tool, options: (
32+
tool.encode(),
33+
True,
34+
),
3135
)
3236

3337
self.assertEqual(data, b"trace_viewer@")
@@ -38,7 +42,11 @@ def test_using_new_tool_format_does_not_map_to_old_format(self):
3842
xspace_paths=["/path/to/xspace"],
3943
tool="trace_viewer@",
4044
params={},
41-
xspace_wrapper_func=lambda paths, tool, options: (tool.encode(), True),
45+
all_hosts=[],
46+
xspace_wrapper_func=lambda paths, hosts, tool, options: (
47+
tool.encode(),
48+
True,
49+
),
4250
)
4351

4452
self.assertEqual(data, b"trace_viewer@")

plugin/xprof/integration_tests/tpu/tensorflow/tpu_tf2_keras_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def test_tools_are_in_list(self):
114114

115115
def test_overview_page(self):
116116
xspace_filenames = self._get_session_snapshot()
117-
result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames,
117+
result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, [],
118118
'overview_page', {})
119119
result = json.loads(result)
120120
run_environment = result[2]
@@ -123,7 +123,9 @@ def test_overview_page(self):
123123

124124
def test_overview_page_creates_cache(self):
125125
xspace_filenames = self._get_session_snapshot()
126-
raw_to_tool_data.xspace_to_tool_data(xspace_filenames, 'overview_page', {})
126+
raw_to_tool_data.xspace_to_tool_data(
127+
xspace_filenames, [], 'overview_page', {}
128+
)
127129
profile_plugin_root = os.path.join(log_dir, 'plugins/profile')
128130
# The session exists under a director whose name is time-dependent.
129131
cache_glob = os.path.join(profile_plugin_root, '*', '*.op_stats.pb')
@@ -132,7 +134,7 @@ def test_overview_page_creates_cache(self):
132134
def test_op_profile(self):
133135
xspace_filenames = self._get_session_snapshot()
134136
result, _ = raw_to_tool_data.xspace_to_tool_data(
135-
xspace_filenames, 'op_profile', {'group_by': 'category'}
137+
xspace_filenames, [], 'op_profile', {'group_by': 'category'}
136138
)
137139
result = json.loads(result)
138140
logging.info(result)
@@ -151,7 +153,7 @@ def test_op_profile(self):
151153
def test_device_trace_contains_threads(self):
152154
xspace_filenames = self._get_session_snapshot()
153155
result, _ = raw_to_tool_data.xspace_to_tool_data(
154-
xspace_filenames, 'trace_viewer', {}
156+
xspace_filenames, [], 'trace_viewer', {}
155157
)
156158
result = json.loads(result)
157159
thread_names = []

plugin/xprof/profile_plugin.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ def hlo_module_list_route(
707707

708708
def _get_valid_hosts(
709709
self, run_dir: str, run: str, tool: str, hosts_param: str, host: str
710-
) -> tuple[List[str], List[epath.Path]]:
710+
) -> tuple[List[str], List[epath.Path], List[str]]:
711711
"""Retrieves and validates the hosts and asset paths for a run and tool.
712712
713713
Args:
@@ -718,7 +718,7 @@ def _get_valid_hosts(
718718
host: The single host parameter.
719719
720720
Returns:
721-
A tuple containing (selected_hosts, asset_paths).
721+
A tuple containing (selected_hosts, asset_paths, all_hosts).
722722
723723
Raises:
724724
FileNotFoundError: If a required xplane file for the specified host(s)
@@ -781,7 +781,9 @@ def _get_valid_hosts(
781781
'Host must be specified for tool %s in run %s' % (tool, run)
782782
)
783783

784-
return selected_hosts, asset_paths
784+
all_hosts = list(all_xplane_files.keys())
785+
786+
return selected_hosts, asset_paths, all_hosts
785787

786788
def data_impl(
787789
self, request: wrappers.Request
@@ -870,7 +872,7 @@ def data_impl(
870872

871873
_, content_encoding = None, None
872874
if use_xplane(tool):
873-
selected_hosts, asset_paths = self._get_valid_hosts(
875+
selected_hosts, asset_paths, all_hosts = self._get_valid_hosts(
874876
run_dir, run, tool, hosts_param, host
875877
)
876878
if not asset_paths:
@@ -879,7 +881,7 @@ def data_impl(
879881
params['hosts'] = selected_hosts
880882
try:
881883
data, content_type = convert.xspace_to_tool_data(
882-
asset_paths, tool, params)
884+
asset_paths, all_hosts, tool, params)
883885
except AttributeError as e:
884886
logger.warning('Error generating analysis results due to %s', e)
885887
raise AttributeError(

plugin/xprof/profile_plugin_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def testDataImplTraceViewerOptions(self, mock_xspace_to_tool_data):
465465
)
466466

467467
mock_xspace_to_tool_data.assert_called_once_with(
468-
[mock.ANY], 'trace_viewer@', expected_params
468+
[mock.ANY], ['host0', 'host1'], 'trace_viewer@', expected_params
469469
)
470470
args, _ = mock_xspace_to_tool_data.call_args
471471
actual_path_list = args[0]

xprof/convert/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ cc_library(
191191
":repository",
192192
":tool_options",
193193
":xplane_to_trace_container",
194+
"@com_google_absl//absl/container:flat_hash_map",
194195
"@com_google_absl//absl/log",
195196
"@com_google_absl//absl/status",
196197
"@com_google_absl//absl/status:statusor",

xprof/convert/repository.cc

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ static auto* kHostDataSuffixes =
5858

5959
absl::StatusOr<SessionSnapshot> SessionSnapshot::Create(
6060
std::vector<std::string> xspace_paths,
61-
std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces) {
61+
std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces,
62+
std::optional<std::vector<std::string>> all_hosts) {
6263
if (xspace_paths.empty()) {
6364
return absl::InvalidArgumentError("Can not find XSpace path.");
6465
}
@@ -85,7 +86,26 @@ absl::StatusOr<SessionSnapshot> SessionSnapshot::Create(
8586
}
8687
}
8788

88-
return SessionSnapshot(std::move(xspace_paths), std::move(xspaces));
89+
return SessionSnapshot(std::move(xspace_paths), std::move(xspaces),
90+
std::move(all_hosts));
91+
}
92+
93+
SessionSnapshot::SessionSnapshot(
94+
std::vector<std::string> xspace_paths,
95+
std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces,
96+
std::optional<std::vector<std::string>> all_hosts)
97+
: xspace_paths_(std::move(xspace_paths)),
98+
all_hosts_(std::move(all_hosts)),
99+
// If the snapshot was initialized by xspaces, the file path and run dir
100+
// is a path tensorflow can't read from or write to so any file IO
101+
// encapsulated in this class will be disabled in this mode.
102+
has_accessible_run_dir_(!xspaces.has_value()),
103+
xspaces_(std::move(xspaces)) {
104+
session_run_dir_ = tsl::io::Dirname(xspace_paths_.at(0));
105+
for (size_t i = 0; i < xspace_paths_.size(); ++i) {
106+
std::string host_name = GetHostname(i);
107+
hostname_map_[host_name] = i;
108+
}
89109
}
90110

91111
absl::StatusOr<XSpace*> SessionSnapshot::GetXSpace(size_t index,
@@ -126,6 +146,10 @@ std::string SessionSnapshot::GetHostname(size_t index) const {
126146
return GetHostnameByPath(xspace_paths_.at(index));
127147
}
128148

149+
std::optional<std::vector<std::string>> SessionSnapshot::GetAllHosts() const {
150+
return all_hosts_;
151+
}
152+
129153
std::optional<std::string> SessionSnapshot::GetFilePath(
130154
absl::string_view toolname, absl::string_view hostname) const {
131155
if (!has_accessible_run_dir_) return std::nullopt;

0 commit comments

Comments
 (0)