Commit b458040

Author: Patrick Nguyen (committed)
Merge commit for internal changes
2 parents: a3ef451 + 5c26ec2, commit: b458040

File tree: 94 files changed (+4418, -1899 lines)


WORKSPACE (+4, -20)
@@ -22,26 +22,10 @@ check_bazel_version_at_least("0.10.0")
 
 load("//tensorflow:workspace.bzl", "tf_workspace")
 
-# Uncomment and update the paths in these entries to build the Android demo.
-#android_sdk_repository(
-#    name = "androidsdk",
-#    api_level = 23,
-#    # Ensure that you have the build_tools_version below installed in the
-#    # SDK manager as it updates periodically.
-#    build_tools_version = "26.0.1",
-#    # Replace with path to Android SDK on your system
-#    path = "<PATH_TO_SDK>",
-#)
-#
-#android_ndk_repository(
-#    name="androidndk",
-#    path="<PATH_TO_NDK>",
-#    # This needs to be 14 or higher to compile TensorFlow.
-#    # Please specify API level >= 21 to build for 64-bit architecture
-#    # otherwise the Android NDK will automatically select the latest
-#    # API level it does support without notice.
-#    # Note that the NDK version is not the API level.
-#    api_level=14)
+load("//third_party/android:android_configure.bzl", "android_configure")
+android_configure(name="local_config_android")
+load("@local_config_android//:android.bzl", "android_workspace")
+android_workspace()
 
 # Please add all new TensorFlow dependencies in workspace.bzl.
 tf_workspace()

configure.py (+29, -65)
@@ -670,8 +670,9 @@ def valid_ndk_path(path):
       error_msg=('The path %s or its child file "source.properties" '
                  'does not exist.')
   )
-
-  write_android_ndk_workspace_rule(android_ndk_home_path)
+  write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path)
+  write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL',
+                              check_ndk_level(android_ndk_home_path))
 
 
 def create_android_sdk_rule(environ_cp):
@@ -733,41 +734,12 @@ def valid_build_tools(version):
       error_msg=('The selected SDK does not have build-tools version %s '
                  'available.'))
 
-  write_android_sdk_workspace_rule(android_sdk_home_path,
-                                   android_build_tools_version,
-                                   android_api_level)
-
-
-def write_android_sdk_workspace_rule(android_sdk_home_path,
-                                     android_build_tools_version,
-                                     android_api_level):
-  print('Writing android_sdk_workspace rule.\n')
-  with open(_TF_WORKSPACE, 'a') as f:
-    f.write("""
-android_sdk_repository(
-    name="androidsdk",
-    api_level=%s,
-    path="%s",
-    build_tools_version="%s")\n
-""" % (android_api_level, android_sdk_home_path, android_build_tools_version))
-
-
-def write_android_ndk_workspace_rule(android_ndk_home_path):
-  print('Writing android_ndk_workspace rule.')
-  ndk_api_level = check_ndk_level(android_ndk_home_path)
-  if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
-    print('WARNING: The API level of the NDK in %s is %s, which is not '
-          'supported by Bazel (officially supported versions: %s). Please use '
-          'another version. Compiling Android targets may result in confusing '
-          'errors.\n' % (android_ndk_home_path, ndk_api_level,
-                         _SUPPORTED_ANDROID_NDK_VERSIONS))
-  with open(_TF_WORKSPACE, 'a') as f:
-    f.write("""
-android_ndk_repository(
-    name="androidndk",
-    path="%s",
-    api_level=%s)\n
-""" % (android_ndk_home_path, ndk_api_level))
+  write_action_env_to_bazelrc('ANDROID_BUILD_TOOLS_VERSION',
+                              android_build_tools_version)
+  write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL',
+                              android_api_level)
+  write_action_env_to_bazelrc('ANDROID_SDK_HOME',
+                              android_sdk_home_path)
 
 
 def check_ndk_level(android_ndk_home_path):
@@ -780,18 +752,16 @@ def check_ndk_level(android_ndk_home_path):
 
   revision = re.search(r'Pkg.Revision = (\d+)', filedata)
   if revision:
-    return revision.group(1)
-  return None
-
-
-def workspace_has_any_android_rule():
-  """Check the WORKSPACE for existing android_*_repository rules."""
-  with open(_TF_WORKSPACE, 'r') as f:
-    workspace = f.read()
-  has_any_rule = re.search(r'^android_[ns]dk_repository',
-                           workspace,
-                           re.MULTILINE)
-  return has_any_rule
+    ndk_api_level = revision.group(1)
+  else:
+    raise Exception('Unable to parse NDK revision.')
+  if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
+    print('WARNING: The API level of the NDK in %s is %s, which is not '
+          'supported by Bazel (officially supported versions: %s). Please use '
+          'another version. Compiling Android targets may result in confusing '
+          'errors.\n' % (android_ndk_home_path, ndk_api_level,
+                         _SUPPORTED_ANDROID_NDK_VERSIONS))
+  return ndk_api_level
 
 
 def set_gcc_host_compiler_path(environ_cp):
@@ -1223,7 +1193,7 @@ def set_tf_cuda_compute_capabilities(environ_cp):
     # Check whether all capabilities from the input is valid
     all_valid = True
     # Remove all whitespace characters before splitting the string
-    # that users may insert by accident, as this will result in error 
+    # that users may insert by accident, as this will result in error
     tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split())
     for compute_capability in tf_cuda_compute_capabilities.split(','):
       m = re.match('[0-9]+.[0-9]+', compute_capability)
@@ -1556,21 +1526,15 @@ def main():
   set_build_strip_flag()
   set_windows_build_flags()
 
-  if workspace_has_any_android_rule():
-    print('The WORKSPACE file has at least one of ["android_sdk_repository", '
-          '"android_ndk_repository"] already set. Will not ask to help '
-          'configure the WORKSPACE. Please delete the existing rules to '
-          'activate the helper.\n')
-  else:
-    if get_var(
-        environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace',
-        False,
-        ('Would you like to interactively configure ./WORKSPACE for '
-         'Android builds?'),
-        'Searching for NDK and SDK installations.',
-        'Not configuring the WORKSPACE for Android builds.'):
-      create_android_ndk_rule(environ_cp)
-      create_android_sdk_rule(environ_cp)
+  if get_var(
+      environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace',
+      False,
+      ('Would you like to interactively configure ./WORKSPACE for '
+       'Android builds?'),
+      'Searching for NDK and SDK installations.',
+      'Not configuring the WORKSPACE for Android builds.'):
+    create_android_ndk_rule(environ_cp)
+    create_android_sdk_rule(environ_cp)
 
   print('Preconfigured Bazel build configs. You can use any of the below by '
         'adding "--config=<>" to your build command. See tools/bazel.rc for '

tensorflow/compiler/xla/BUILD (-1)
@@ -53,7 +53,6 @@ xla_proto_library(
     deps = [
         ":xla_data_proto",
         "//tensorflow/compiler/xla/service:hlo_proto",
-        "//tensorflow/compiler/xla/service:session_proto",
     ],
 )
 

tensorflow/compiler/xla/client/BUILD (+1)
@@ -110,6 +110,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:compiler",
         "//tensorflow/compiler/xla/service:device_memory_allocator",
         "//tensorflow/compiler/xla/service:executable",
+        "//tensorflow/compiler/xla/service:hlo_proto",
         "//tensorflow/compiler/xla/service:local_service",
         "//tensorflow/compiler/xla/service:shaped_buffer",
         "//tensorflow/compiler/xla/service:source_map_util",

tensorflow/compiler/xla/client/local_client.cc (+11, -11)
@@ -185,7 +185,7 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
       run_options, backend_->StreamBorrower(),
       backend_->eigen_intra_op_thread_pool());
 
-  if (executable_->dumping()) {
+  if (executable_->dumping_snapshot()) {
     return ExecuteAndDump(&service_options, arguments);
   }
   return executable_->ExecuteOnStreamWrapper(
@@ -195,36 +195,36 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
 StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
     const ServiceExecutableRunOptions* run_options,
     const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
-  executable_->session_module()->set_execution_platform(
+  executable_->hlo_snapshot()->set_execution_platform(
       backend_->platform()->Name());
-  TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->session_module()));
+  TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot()));
   TF_ASSIGN_OR_RETURN(
       ScopedShapedBuffer result,
       executable_->ExecuteOnStream(run_options, arguments,
                                    /*hlo_execution_profile=*/nullptr));
-  TF_RETURN_IF_ERROR(RecordResult(&result, executable_->session_module()));
-  TF_RETURN_IF_ERROR(executable_->DumpSessionModule());
+  TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot()));
+  TF_RETURN_IF_ERROR(executable_->DumpHloSnapshot());
   return std::move(result);
 }
 
 Status LocalExecutable::RecordArguments(
     const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
-    SessionModule* session_module) {
-  session_module->clear_arguments();
+    HloSnapshot* hlo_snapshot) {
+  hlo_snapshot->clear_arguments();
   for (const ShapedBuffer* argument : arguments) {
     TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
                         LiteralFromShapedBuffer(*argument));
-    *session_module->add_arguments() = literal->ToProto();
+    *hlo_snapshot->add_arguments() = literal->ToProto();
   }
   return Status::OK();
 }
 
 Status LocalExecutable::RecordResult(const ShapedBuffer* result,
-                                     SessionModule* session_module) {
-  session_module->clear_result();
+                                     HloSnapshot* hlo_snapshot) {
+  hlo_snapshot->clear_result();
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
                       LiteralFromShapedBuffer(*result));
-  *session_module->mutable_result() = literal->ToProto();
+  *hlo_snapshot->mutable_result() = literal->ToProto();
   return Status::OK();
 }
 
tensorflow/compiler/xla/client/local_client.h (+3, -3)
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
 #include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/local_service.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -78,11 +79,10 @@ class LocalExecutable {
   // proto.
   Status RecordArguments(
       const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
-      SessionModule* session_module);
+      HloSnapshot* hlo_snapshot);
 
   // Records the result of the computation in a SessionModule proto.
-  Status RecordResult(const ShapedBuffer* result,
-                      SessionModule* session_module);
+  Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot);
 
   // Returns a literal containing the contents of the given ShapedBuffer.
   StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(

tensorflow/compiler/xla/client/xla_client/xla_builder.cc (+23, -1)
@@ -1613,13 +1613,35 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
 
 XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) {
   return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
+    TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
+    const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {});
+    auto b = CreateSubBuilder("sum");
+    b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"),
+           b->Parameter(/*parameter_number=*/1, scalar_shape, "y"));
+    TF_ASSIGN_OR_RETURN(auto computation, b->Build());
+    return CrossReplicaSum(operand, computation, /*replica_group_ids=*/{},
+                           /*channel_id=*/tensorflow::gtl::nullopt);
+  });
+}
+
+XlaOp XlaBuilder::CrossReplicaSum(
+    const XlaOp& operand, const XlaComputation& computation,
+    tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+    const tensorflow::gtl::optional<ChannelHandle>& channel_id) {
+  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    if (!replica_group_ids.empty() || channel_id.has_value()) {
+      return Unimplemented(
+          "replica_group_ids and channel_id and is not supported in AllReduce");
+    }
 
+    HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
     TF_ASSIGN_OR_RETURN(
         *instr.mutable_shape(),
         ShapeInference::InferCrossReplicaSumShape({&operand_shape}));
 
+    AddCalledComputation(computation, &instr);
+
     return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum,
                           {operand});
   });

tensorflow/compiler/xla/client/xla_client/xla_builder.h (+23)
@@ -532,6 +532,29 @@ class XlaBuilder {
   // supply one input to the sum and all replicas receive the resulting sum.
   XlaOp CrossReplicaSum(const XlaOp& operand);
 
+  // Enqueues an operation that do an AllReduce of the operand cross cores. Here
+  // AllReduce means doing a reduction on the input operand cross cores and then
+  // broadcasting the reduction result to those cores. The reduction function is
+  // defined by `computation`, which should be a commutative computation on
+  // scalars, e.g., add, min, or max. The way that AllReduce is applied is
+  // configured by:
+  //
+  // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
+  // replicas belong to one group. Allreduce will be applied within subgroups.
+  // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
+  // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
+  //
+  // - `channel_id`: for Allreduce nodes from different models, if they have the
+  // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
+  // applied cross models.
+  //
+  // TODO(b/79737069): Rename this to AllReduce when it's ready to use.
+  XlaOp CrossReplicaSum(
+      const XlaOp& operand, const XlaComputation& computation,
+      tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
+      const tensorflow::gtl::optional<ChannelHandle>& channel_id =
+          tensorflow::gtl::nullopt);
+
   // Enqueues an operation that scatters the `source` array to the selected
   // indices of each window.
   XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
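
For reference, here is a minimal usage sketch (not part of this commit) of the new overload declared above. It mirrors the scalar "sum" computation that the single-argument CrossReplicaSum now builds internally in xla_builder.cc; the wrapper name CrossReplicaSumWithExplicitSum is hypothetical, and error handling is elided with ValueOrDie().

// Sketch only: build an explicit scalar-add reduction and pass it to the new
// CrossReplicaSum overload added by this commit.
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// Hypothetical helper, shown for illustration only.
XlaOp CrossReplicaSumWithExplicitSum(XlaBuilder* builder, const XlaOp& operand,
                                     PrimitiveType element_type) {
  // Build the scalar reduction computation (here: addition), exactly as the
  // single-argument CrossReplicaSum overload now does internally.
  const Shape scalar_shape = ShapeUtil::MakeShape(element_type, {});
  auto sum_builder = builder->CreateSubBuilder("sum");
  sum_builder->Add(
      sum_builder->Parameter(/*parameter_number=*/0, scalar_shape, "x"),
      sum_builder->Parameter(/*parameter_number=*/1, scalar_shape, "y"));
  XlaComputation sum = sum_builder->Build().ValueOrDie();

  // replica_group_ids and channel_id keep their defaults ({} and nullopt);
  // this commit still reports Unimplemented for any other values.
  return builder->CrossReplicaSum(operand, sum);
}

}  // namespace xla

The single-argument overload kept above is now just this pattern specialized to addition, so existing callers are unchanged while passes and tests (for example the bfloat16 folding test below) can supply their own reduction computation.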

tensorflow/compiler/xla/service/BUILD (-10)
@@ -21,13 +21,6 @@ load(
     "tf_proto_library_py",
 )
 
-xla_proto_library(
-    name = "session_proto",
-    srcs = ["session.proto"],
-    visibility = ["//visibility:public"],
-    deps = ["//tensorflow/compiler/xla:xla_data_proto"],
-)
-
 xla_proto_library(
     name = "hlo_proto",
     srcs = ["hlo.proto"],
@@ -608,7 +601,6 @@ cc_library(
         ":hlo_module_config",
         ":hlo_proto_util",
         ":platform_util",
-        ":session_proto",
         ":source_map_util",
         ":transfer_manager",
         ":versioned_computation_handle",
@@ -766,7 +758,6 @@ cc_library(
         ":hlo_graph_dumper",
         ":hlo_proto",
         ":pool",
-        ":session_proto",
         ":shaped_buffer",
         ":versioned_computation_handle",
         "//tensorflow/compiler/xla:executable_run_options",
@@ -870,7 +861,6 @@ cc_library(
     hdrs = ["channel_tracker.h"],
     deps = [
         ":hlo",
-        ":session_proto",
         ":versioned_computation_handle",
         "//tensorflow/compiler/xla:status",
         "//tensorflow/compiler/xla:status_macros",

tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc (+13, -2)
@@ -211,6 +211,17 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
 
 TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
   auto builder = HloComputation::Builder(TestName());
+
+  auto module = CreateNewModule();
+  HloComputation::Builder sum_builder("add");
+  auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
+  auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y"));
+  sum_builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y));
+  HloComputation* sum = module->AddEmbeddedComputation(sum_builder.Build());
+
   Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
   Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
 
@@ -223,7 +234,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
 
   HloInstruction* crs =
       builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
-          ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}));
+          ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b},
+          sum));
   HloInstruction* gte_a = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
   HloInstruction* gte_b = builder.AddInstruction(
@@ -233,7 +245,6 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
   HloInstruction* tuple = builder.AddInstruction(
       HloInstruction::CreateTuple({gte_a, convert_gte_b}));
 
-  auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_TRUE(FoldConversions(module.get()));
