@@ -169,8 +169,8 @@ def while_loop_body(iteration, matrix, inactive, old_inactive):
169
169
del old_inactive # Needed by the condition, but not the body.
170
170
iteration += 1
171
171
scale = (1.0 - standard_ops .reduce_sum (
172
- matrix , axis = 0 , keep_dims = True )) / standard_ops .maximum (
173
- 1.0 , standard_ops .reduce_sum (inactive , axis = 0 , keep_dims = True ))
172
+ matrix , axis = 0 , keepdims = True )) / standard_ops .maximum (
173
+ 1.0 , standard_ops .reduce_sum (inactive , axis = 0 , keepdims = True ))
174
174
matrix += scale * inactive
175
175
new_inactive = standard_ops .to_float (matrix > 0 )
176
176
matrix *= new_inactive
@@ -206,10 +206,10 @@ def _project_log_stochastic_matrix_wrt_kl_divergence(log_matrix):
206
206
207
207
# For numerical reasons, make sure that the largest matrix element is zero
208
208
# before exponentiating.
209
- log_matrix -= standard_ops .reduce_max (log_matrix , axis = 0 , keep_dims = True )
209
+ log_matrix -= standard_ops .reduce_max (log_matrix , axis = 0 , keepdims = True )
210
210
log_matrix -= standard_ops .log (
211
211
standard_ops .reduce_sum (
212
- standard_ops .exp (log_matrix ), axis = 0 , keep_dims = True ))
212
+ standard_ops .exp (log_matrix ), axis = 0 , keepdims = True ))
213
213
return log_matrix
214
214
215
215
0 commit comments