Skip to content

Commit 70bd15e

Browse files
candyzoneliutongxuan
authored and committed
[Embedding] Fix build save graph bug when creating partitioned EmbeddingVariable in feature_column API. (#521)
1 parent 125b6ed commit 70bd15e

File tree

4 files changed

+49
-1
lines changed

4 files changed

+49
-1
lines changed

tensorflow/python/feature_column/feature_column.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@
155155
from tensorflow.python.ops import math_ops
156156
from tensorflow.python.ops import nn_ops
157157
from tensorflow.python.ops import parsing_ops
158+
from tensorflow.python.ops import partitioned_variables
158159
from tensorflow.python.ops import resource_variable_ops
159160
from tensorflow.python.ops import sparse_ops
160161
from tensorflow.python.ops import string_ops

tensorflow/python/feature_column/feature_column_v2_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7670,6 +7670,26 @@ def testEmbeddingVariableForSharedEmbeddingColumns(self):
76707670
for j in range(3):
76717671
self.assertAlmostEqual(emb_r[i][j], emb_right[i][j])
76727672

7673+
@test_util.run_deprecated_v1
7674+
def testEmbeddingVariableForSharedEmbeddingColumnsWithPartitionNum(self):
7675+
columns_list=[]
7676+
columns_list.append(fc.categorical_column_with_embedding("col_emb", dtype=dtypes.string, partition_num=4))
7677+
W = fc.shared_embedding_columns(columns_list,
7678+
dimension=3,
7679+
initializer=init_ops.ones_initializer(dtypes.float32),
7680+
shared_embedding_collection_name="xxxxx_shared")
7681+
7682+
ids={}
7683+
ids["col_emb"] = sparse_tensor.SparseTensor(indices=[[0,0],[1,0],[2,0],[3,0],[4,0]], values=["aaaa","bbbbb","ccc","4nn","5b"], dense_shape=[5, 5])
7684+
emb = fc_old.input_layer(ids, W)
7685+
fun = math_ops.multiply(emb, 2.0, name='multiply')
7686+
loss = math_ops.reduce_sum(fun, name='reduce_sum')
7687+
opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
7688+
g_v = opt.compute_gradients(loss)
7689+
train_op = opt.apply_gradients(g_v)
7690+
init = variables_lib.global_variables_initializer()
7691+
saver = saver_module.Saver()
7692+
76737693
@test_util.run_deprecated_v1
76747694
def test_transform_feature(self):
76757695
a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)

tensorflow/python/ops/embedding_variable_ops_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2281,5 +2281,31 @@ def runTestAdagrad(self, var, g):
22812281
for j in range(0, 3):
22822282
self.assertEqual(emb1[i][j], emb2[i][j])
22832283

2284+
def testEmbeddingVariableForContirbFeatureColumnWithPartitionNum(self):
2285+
print("testEmbeddingVariableForContirbFeatureColumnWithPartitionNum")
2286+
checkpoint_directory = self.get_temp_dir()
2287+
evict = variables.L2WeightEvict(l2_weight_threshold=0.9)
2288+
columns = feature_column.sparse_column_with_embedding(
2289+
column_name="col_emb",
2290+
dtype=dtypes.int64,
2291+
partition_num = 4)
2292+
W = feature_column.embedding_column(sparse_id_column=columns,
2293+
dimension=3,
2294+
initializer=init_ops.ones_initializer(dtypes.float32),
2295+
combiner="mean")
2296+
ids = {}
2297+
ids["col_emb"] = sparse_tensor.SparseTensor(
2298+
indices=[[0,0],[1,0],[2,0],[3,0],[4,0],[5,0]],
2299+
values=math_ops.cast([0,0,0,1,1,2], dtypes.int64),
2300+
dense_shape=[6, 1])
2301+
emb= feature_column_ops.input_from_feature_columns(
2302+
columns_to_tensors=ids, feature_columns=[W])
2303+
fun = math_ops.multiply(emb, 2.0, name='multiply')
2304+
loss = math_ops.reduce_sum(fun, name='reduce_sum')
2305+
opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
2306+
g_v = opt.compute_gradients(loss)
2307+
train_op = opt.apply_gradients(g_v)
2308+
saver = saver_module.Saver()
2309+
22842310
if __name__ == "__main__":
22852311
googletest.main()

tensorflow/python/training/saver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,8 @@ def _build_internal(self,
701701
# not need it, but we'll try looking it up on MetaGraph restore
702702
# since it's in a collection.
703703
params=list(element)
704-
if not isinstance(params[0], kv_variable_ops.DynamicEmbeddingVariable):
704+
if not isinstance(params[0], kv_variable_ops.DynamicEmbeddingVariable) and \
705+
not isinstance(params[0], kv_variable_ops.EmbeddingVariable):
705706
element.as_tensor()
706707
return saver_pb2.SaverDef(
707708
filename_tensor_name=filename_tensor.name,

0 commit comments

Comments (0)